Spaces:
Sleeping
Sleeping
# Streamlit is imported first because set_page_config must precede every
# other st.* call on the page.
import streamlit as st

st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Scenario Planner",
    page_icon="⚖️",
)

# Inject CSS: hide the +/- stepper buttons of number inputs and round the
# corners of BaseWeb widgets.
st.markdown(
    unsafe_allow_html=True,
    body="""
<style>
button.step-up {display: none;}
button.step-down {display: none;}
div[data-baseweb] {border-radius: 4px;}
</style>""",
)
import re | |
import sys | |
import copy | |
import pickle | |
import traceback | |
import numpy as np | |
import pandas as pd | |
from scenario import numerize | |
import plotly.graph_objects as go | |
from post_gres_cred import db_cred | |
from scipy.optimize import minimize | |
from log_application import log_message | |
from utilities import project_selection, update_db, set_header, load_local_css | |
from utilities import ( | |
get_panels_names, | |
get_metrics_names, | |
name_formating, | |
load_rcs_metadata_files, | |
load_scenario_metadata_files, | |
generate_rcs_data, | |
generate_scenario_data, | |
) | |
from constants import ( | |
xtol_tolerance_per, | |
mroi_threshold, | |
word_length_limit_lower, | |
word_length_limit_upper, | |
) | |
schema = db_cred["schema"]

# Page chrome: project stylesheet and shared header.
load_local_css("styles.css")
set_header()

# No project chosen yet -> explicit None placeholder.
st.session_state.setdefault("project_name", None)

# Without a loaded project dictionary, show the project picker and end this run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Greeting row (username | current project) for a signed-in user only.
if st.session_state.get("username") is not None:
    cols1 = st.columns([2, 1])
    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")

# Session defaults: ROI threshold and an empty message slot.
st.session_state.setdefault("roi_threshold", 1)
st.session_state.setdefault(
    "message_display", {"type": "success", "message": None, "icon": ""}
)
# Function to reset modified_scenario_data | |
def reset_scenario(metrics_selected=None, panel_selected=None):
    """Restore one metric/panel slice of the modified scenario from the original.

    Args:
        metrics_selected: Response-metric key. Defaults to the session-state
            value of ``response_metrics_selectbox_sp``.
        panel_selected: Panel key. Defaults to the session-state value of
            ``panel_selected_selectbox_sp``.

    Side effects:
        Clears ``st.session_state.message_display``. Replaces
        ``modified_metadata_file[metrics_selected][panel_selected]`` with a
        deep copy of the matching ``original_metadata_file`` entry.
    """
    session = st.session_state
    session.message_display = {"type": "success", "message": None, "icon": ""}

    metrics_selected = (
        session["response_metrics_selectbox_sp"]
        if metrics_selected is None
        else metrics_selected
    )
    panel_selected = (
        session["panel_selected_selectbox_sp"]
        if panel_selected is None
        else panel_selected
    )

    planner = session["project_dct"]["scenario_planner"]
    pristine_slice = planner["original_metadata_file"][metrics_selected][
        panel_selected
    ]
    modified_metadata = planner["modified_metadata_file"]
    # A deep copy keeps later edits to the modified scenario from leaking
    # back into the original metadata.
    modified_metadata[metrics_selected][panel_selected] = copy.deepcopy(
        pristine_slice
    )
    planner["modified_metadata_file"] = modified_metadata
# Function to build s curve | |
def s_curve(x, power, K, b, a, x0):
    """Evaluate the logistic response curve ``K / (1 + b * exp(-a * (x / 10**power - x0)))``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    scaled_x = x / 10**power
    exponent = -a * (scaled_x - x0)
    return K / (1 + b * np.exp(exponent))
# Function to retrieve S-curve parameters for a given metric, panel, and channel | |
def get_s_curve_params(
    metrics_selected,
    panel_selected,
    channel_selected,
    original_rcs_data,
    modified_rcs_data,
):
    """Return a channel's modified S-curve parameters with the original ``power``.

    The modified parameter dict is also written into the modified scenario
    metadata as the channel's ``response_curve_params``. The same dict object
    is stored there and returned, and ``power`` is overwritten on it
    afterwards. The power override is therefore visible in the metadata and in
    ``modified_rcs_data`` as well.

    Returns:
        The (mutated) parameter dict from ``modified_rcs_data``.
    """
    original_power = original_rcs_data[metrics_selected][panel_selected][
        channel_selected
    ]["power"]
    curve_params = modified_rcs_data[metrics_selected][panel_selected][
        channel_selected
    ]

    planner = st.session_state["project_dct"]["scenario_planner"]
    scenario_metadata = planner["modified_metadata_file"]
    channel_entry = scenario_metadata[metrics_selected][panel_selected]["channels"][
        channel_selected
    ]
    channel_entry["response_curve_params"] = curve_params
    planner["modified_metadata_file"] = scenario_metadata

    # Shared dict: this also updates the copy just stored in the metadata.
    curve_params["power"] = original_power
    return curve_params
# Function to calculate total contribution | |
def get_total_contribution(
    spends, channels, s_curve_params, channels_proportion, modified_scenario_data
):
    """Return the summed modelled contribution for a vector of channel spends.

    For each channel, ``spends[idx] * channels_proportion[channel]`` is pushed
    through ``s_curve`` and the outputs are summed. The sum of the channel's
    ``correction`` values is added to that. Finally, the sum of the scenario's
    ``constant`` values is added.

    Args:
        spends: Indexable spends, positionally aligned with ``channels``.
        channels: Channel names.
        s_curve_params: Channel -> dict with ``power``, ``K``, ``b``, ``a``, ``x0``.
        channels_proportion: Channel -> multiplier applied to the spend
            (iterated by ``sum``, so presumably an array — TODO confirm).
        modified_scenario_data: Dict with ``channels[name]["correction"]`` and
            ``constant`` iterables.
    """
    total_contribution = 0
    for idx, channel_name in enumerate(channels):
        params = s_curve_params[channel_name]
        curve_values = s_curve(
            spends[idx] * channels_proportion[channel_name],
            params["power"],
            params["K"],
            params["b"],
            params["a"],
            params["x0"],
        )
        correction_total = sum(
            modified_scenario_data["channels"][channel_name]["correction"]
        )
        total_contribution += sum(curve_values) + correction_total
    return total_contribution + sum(modified_scenario_data["constant"])
# Function to calculate total spends | |
def get_total_spends(spends, channels_conversion_ratio):
    """Return ``sum(spends[i] * ratio[i])``.

    The conversion ratios are taken positionally in the dict's value order, so
    that order must match the order of ``spends``.
    """
    ratios = np.array(list(channels_conversion_ratio.values()))
    return (spends * ratios).sum()
# Function to optimize spends for all channels given per-channel bounds and a total spend or metric target
def optimizer(
    optimization_goal,
    s_curve_params,
    channels_spends,
    channels_proportion,
    channels_conversion_ratio,
    total_target,
    bounds_dict,
    modified_scenario_data,
):
    """Optimise per-channel spends under per-channel bounds and one total target.

    When ``optimization_goal == "Spend"``, the total contribution is maximised
    subject to ``get_total_spends(spends) == total_target``. For any other
    goal, the total spend is minimised subject to
    ``get_total_contribution(spends) == total_target``. SciPy's
    ``trust-constr`` method is used, starting from the current spends.

    Args:
        optimization_goal: ``"Spend"`` or another goal label (metric target).
        s_curve_params: Channel -> S-curve parameter dict.
        channels_spends: Channel -> current spend. Its key order defines the
            decision-vector order.
        channels_proportion: Channel -> proportion passed to the S-curve.
        channels_conversion_ratio: Channel -> conversion ratio. Values are used
            positionally; NOTE(review): confirm callers build it in the same
            order as ``channels_spends``.
        total_target: Value pinned by the equality constraint.
        bounds_dict: Channel -> ``[lower_pct, upper_pct]`` offsets from the
            current spend.
        modified_scenario_data: Scenario slice with corrections and constant.

    Returns:
        ``(optimized_spends, success)``. ``optimized_spends`` maps each channel
        to its optimised spend, with negatives clipped to 0. ``success`` is
        SciPy's ``result.success``.
    """
    channels = list(channels_spends.keys())
    actual_spends = np.array(list(channels_spends.values()))

    def total_contribution(spends):
        return get_total_contribution(
            spends,
            channels,
            s_curve_params,
            channels_proportion,
            modified_scenario_data,
        )

    def total_spend(spends):
        return get_total_spends(spends, channels_conversion_ratio)

    if optimization_goal == "Spend":
        # Maximise contribution (minimise its negative); total spend is fixed.
        def objective_fun(spends):
            return -total_contribution(spends)

        constrained_total = total_spend
    else:
        # Minimise spend; total contribution is fixed.
        objective_fun = total_spend
        constrained_total = total_contribution

    # Box bounds: current spend scaled by the channel's percentage offsets.
    bounds = [
        (
            spend * (1 + bounds_dict[channel][0] / 100),
            spend * (1 + bounds_dict[channel][1] / 100),
        )
        for channel, spend in zip(channels, actual_spends)
    ]

    # Step tolerance: xtol_tolerance_per % of the smallest spend, at least 10.
    xtol = max(10, (xtol_tolerance_per / 100) * np.min(actual_spends))

    result = minimize(
        objective_fun,
        np.array(actual_spends),
        method="trust-constr",
        constraints={
            "type": "eq",
            "fun": lambda spends: constrained_total(spends) - total_target,
        },
        bounds=bounds,
        options={"disp": True, "xtol": xtol, "maxiter": 1e5},
    )

    optimized_spends = {
        channel: max(0, value) for channel, value in zip(channels, result.x)
    }
    return optimized_spends, result.success
# Function to calculate achievable targets at lower and upper spend bounds | |
def max_target_achievable(
    channels_spends,
    s_curve_params,
    channels_proportion,
    modified_scenario_data,
    bounds_dict,
):
    """Return the contributions at the all-lower-bound and all-upper-bound spends.

    Each channel's spend is scaled by its ``bounds_dict`` percentage offsets
    (``[0]`` for lower, ``[1]`` for upper). The two totals are then computed
    with ``get_total_contribution``.

    Returns:
        ``(max(0, 1.001 * lower), 0.999 * upper)``, a 0.1% margin applied to
        each end.
    """
    channel_names = list(channels_spends.keys())
    current_spends = np.array(list(channels_spends.values()))
    name_spend_pairs = list(zip(channel_names, current_spends))

    lower_spends = [
        spend * (1 + bounds_dict[name][0] / 100) for name, spend in name_spend_pairs
    ]
    upper_spends = [
        spend * (1 + bounds_dict[name][1] / 100) for name, spend in name_spend_pairs
    ]

    def contribution_at(spends):
        return get_total_contribution(
            spends,
            channel_names,
            s_curve_params,
            channels_proportion,
            modified_scenario_data,
        )

    lower_target = contribution_at(lower_spends)
    upper_target = contribution_at(upper_spends)
    return max(0, 1.001 * lower_target), 0.999 * upper_target
# Function to check if number is in valid format | |
def is_valid_number_format(number_str):
    """Return True if *number_str* is a finite, non-negative, optionally suffixed number.

    Accepted input is either a plain number parsable by ``float`` (e.g.
    "1500", "1.5e3"), or such a number followed by one case-insensitive suffix
    K, M, B or T (e.g. "1.5K", "2m"). ``None``, empty strings, values starting
    with "-", unknown suffixes, unparsable text and non-finite values
    ("1e999", "infK", NaN) are rejected.

    Side effects:
        On rejection, stores a warning dict in
        ``st.session_state.message_display``.
    """
    valid_suffixes = {"K", "M", "B", "T"}

    def _parses_as_valid(text):
        # None and "" (a cleared text box) are rejected before any indexing.
        if not text or text[0] == "-":
            return False
        if text[-1].isdigit():
            num_part = text
        elif text[-1].upper() in valid_suffixes:
            num_part = text[:-1]
        else:
            return False
        try:
            number = float(num_part)
        except ValueError:
            return False
        # Chained comparison rejects negatives, NaN and +inf in one test.
        return 0 <= number < float("inf")

    if _parses_as_valid(number_str):
        return True

    st.session_state.message_display = {
        "type": "warning",
        "message": "Invalid input: Please enter a valid number.",
        "icon": "⚠️",
    }
    return False
# Function to convert a string with number suffixes (K, M, B, T) to a float
def convert_to_float(number_str):
    """Convert a number string with an optional K/M/B/T suffix to a float.

    Suffixes are case-insensitive: K = 1e3, M = 1e6, B = 1e9, T = 1e12.
    A string ending in a digit is passed straight to ``float``.

    Raises:
        ValueError: If ``number_str`` is empty, ends in an unsupported suffix,
            or its numeric part cannot be parsed by ``float``.
    """
    multipliers = {
        "K": 1e3,  # Thousand
        "M": 1e6,  # Million
        "B": 1e9,  # Billion
        "T": 1e12,  # Trillion
    }

    if not number_str:
        raise ValueError("Cannot convert an empty value to a number.")

    if number_str[-1].isdigit():
        return float(number_str)

    suffix = number_str[-1].upper()
    if suffix not in multipliers:
        # Report bad suffixes the same way float() reports bad numbers.
        raise ValueError(f"Unsupported number suffix: {number_str[-1]!r}")
    return float(number_str[:-1]) * multipliers[suffix]
# Function to update absolute_channel_spends change | |
def absolute_channel_spends_change(
    channel_key, channel_spends_actual, channel, metrics_selected, panel_selected
):
    """Apply an edited absolute spend (``{channel_key}_abs_spends_key``) to one channel.

    The text is validated, converted and scaled by ``st.session_state["multiplier"]``.
    The panel total is recomputed from every channel's absolute-spend input.
    The edit is stored only when that total lies strictly between 50% and 150%
    of the panel's ``actual_total_spends``. In that case the channel's
    ``modified_total_spends`` is set to the value divided by its
    ``conversion_rate``, and the panel's ``modified_total_spends`` to the
    total. Otherwise, a warning is stored in ``st.session_state.message_display``.

    ``channel_spends_actual`` is unused; it is kept for signature compatibility.
    """
    session = st.session_state
    edited_text = session[f"{channel_key}_abs_spends_key"]
    if not is_valid_number_format(edited_text):
        return
    new_absolute_spends = convert_to_float(edited_text) * session["multiplier"]

    data = session["project_dct"]["scenario_planner"]["modified_metadata_file"]
    panel = data[metrics_selected][panel_selected]

    # NOTE(review): other channels' inputs are converted without validation;
    # a malformed value left in another channel's box would raise here.
    # Confirm the page resets those inputs on rerun.
    total_channel_spends = sum(
        convert_to_float(
            session[f"{metrics_selected}_{panel_selected}_{other}_abs_spends_key"]
        )
        * session["multiplier"]
        for other in list(panel["channels"].keys())
    )

    baseline_total = panel["actual_total_spends"]
    if 0.5 * baseline_total < total_channel_spends < 1.5 * baseline_total:
        channel_entry = panel["channels"][channel]
        channel_entry["modified_total_spends"] = new_absolute_spends / float(
            channel_entry["conversion_rate"]
        )
        panel["modified_total_spends"] = total_channel_spends
        session["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
    else:
        session.message_display = {
            "type": "warning",
            "message": "Keep total spending within ±50% of the original value.",
            "icon": "⚠️",
        }
# Function to update percentage_channel_spends change | |
def percentage_channel_spends_change(
    channel_key, channel_spends_actual, channel, metrics_selected, panel_selected
):
    """Apply an edited percentage spend change (``{channel_key}_per_spends_key``) to one channel.

    The percentage is rounded to a whole percent. The channel's new spend is
    ``channel_spends_actual * (1 + pct / 100)``, and it is stored divided by
    the channel's ``conversion_rate``. The panel total is recomputed from every
    channel's percentage input, using the same whole-percent rounding so the
    total agrees with the stored channel value. The update is applied only
    when the total lies strictly between 50% and 150% of the panel's
    ``actual_total_spends``. Otherwise, a warning is stored in
    ``st.session_state.message_display``.

    Args:
        channel_key: Session-state key prefix of the edited channel.
        channel_spends_actual: Baseline spend of the edited channel, in the
            same units as the panel total (presumably actual spend x
            conversion rate — TODO confirm against the caller).
        channel: Channel name within the panel.
        metrics_selected: Response-metric key.
        panel_selected: Panel key.
    """

    def _whole_percent(key_prefix):
        # Percentage inputs are applied at whole-percent resolution.
        return round(st.session_state[f"{key_prefix}_per_spends_key"], 0)

    percentage_channel_spends = _whole_percent(channel_key)
    new_absolute_spends = channel_spends_actual * (1 + percentage_channel_spends / 100)

    data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
    panel_data = data[metrics_selected][panel_selected]

    # Rebuild the panel total with the same rounding as the stored channel
    # value, so the range check matches what is actually saved.
    total_channel_spends = 0
    for current_channel, current_data in panel_data["channels"].items():
        current_key = f"{metrics_selected}_{panel_selected}_{current_channel}"
        total_channel_spends += (
            current_data["actual_total_spends"]
            * current_data["conversion_rate"]
            * (1 + _whole_percent(current_key) / 100)
        )

    actual_total = panel_data["actual_total_spends"]
    if 0.5 * actual_total < total_channel_spends < 1.5 * actual_total:
        channel_data = panel_data["channels"][channel]
        channel_data["modified_total_spends"] = float(new_absolute_spends) / float(
            channel_data["conversion_rate"]
        )
        panel_data["modified_total_spends"] = total_channel_spends
        st.session_state["project_dct"]["scenario_planner"][
            "modified_metadata_file"
        ] = data
    else:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Keep total spending within ±50% of the original value.",
            "icon": "⚠️",
        }
# # Function to update total input change | |
# def total_input_change(per_change): | |
# # Load modified scenario data | |
# data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] | |
# # Get the list of all channels in the specified panel and metric | |
# channel_list = list(data[metrics_selected][panel_selected]["channels"].keys()) | |
# # Iterate over each channel to update their modified spends | |
# for channel in channel_list: | |
# # Retrieve the actual spends for the channel | |
# channel_actual_spends = data[metrics_selected][panel_selected]["channels"][ | |
# channel | |
# ]["actual_total_spends"] | |
# # Calculate the modified spends for the channel based on the percent change | |
# modified_channel_metrics = channel_actual_spends * ((100 + per_change) / 100) | |
# # Update the channel's modified total spends in the data | |
# data[metrics_selected][panel_selected]["channels"][channel][ | |
# "modified_total_spends" | |
# ] = modified_channel_metrics | |
# # Update modified scenario metadata | |
# st.session_state["project_dct"]["scenario_planner"][ | |
# "modified_metadata_file" | |
# ] = data | |
# Function to update total input change | |
def total_input_change(per_change, metrics_selected, panel_selected):
    """Scale a panel's channel spends by ``per_change`` percent, respecting frozen channels.

    A channel is treated as frozen when its session flag
    ``{metrics}_{panel}_{channel}_allow_optimize_key`` is truthy; its
    ``modified_total_spends`` is left untouched. For each frozen channel, the
    gap between its scaled original spend and its current modified spend
    (floored at 1e-3, weighted by conversion rate) is collected. That amount
    is redistributed over the unfrozen channels in proportion to their
    original converted spend (``actual_total_spends * conversion_rate``).
    If every unfrozen channel has zero original spend, the amount is split
    equally instead of dividing by zero.

    Args:
        per_change: Percentage change applied to original channel spends.
        metrics_selected: Response-metric key.
        panel_selected: Panel key.
    """
    modified_data = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ].copy()
    original_data = st.session_state["project_dct"]["scenario_planner"][
        "original_metadata_file"
    ].copy()

    original_channels = original_data[metrics_selected][panel_selected]["channels"]
    modified_channels = modified_data[metrics_selected][panel_selected]["channels"]
    scale = (100 + per_change) / 100

    # Partition channels by the freeze flag held in session state.
    frozen_channels, unfrozen_channels = [], []
    for channel in modified_channels:
        channel_key = f"{metrics_selected}_{panel_selected}_{channel}"
        if st.session_state.get(f"{channel_key}_allow_optimize_key", False):
            frozen_channels.append(channel)
        else:
            unfrozen_channels.append(channel)

    # Converted spend the frozen channels would have absorbed but did not.
    frozen_channel_share = 0
    for channel in frozen_channels:
        conversion_rate = original_channels[channel]["conversion_rate"]
        actual_spends = original_channels[channel]["actual_total_spends"]
        modified_spends = modified_channels[channel]["modified_total_spends"]
        spends_diff = max(actual_spends, 1e-3) * scale - max(modified_spends, 1e-3)
        frozen_channel_share += spends_diff * conversion_rate

    # Weights are normalised over the unfrozen channels themselves, so they
    # always sum to 1 and the full frozen share is redistributed.
    unfrozen_weight_base = sum(
        original_channels[channel]["actual_total_spends"]
        * original_channels[channel]["conversion_rate"]
        for channel in unfrozen_channels
    )

    for channel in unfrozen_channels:
        conversion_rate = original_channels[channel]["conversion_rate"]
        actual_spends = original_channels[channel]["actual_total_spends"]
        if unfrozen_weight_base > 0:
            channel_weight = actual_spends * conversion_rate / unfrozen_weight_base
        else:
            channel_weight = 1 / len(unfrozen_channels)
        modified_channels[channel]["modified_total_spends"] = (
            max(actual_spends, 1e-3) * scale
            + frozen_channel_share * channel_weight / conversion_rate
        )

    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = modified_data
# Function to update total_absolute_main_key change | |
def total_absolute_main_key_change(metrics_selected, panel_selected, optimization_goal):
    """Store a new panel total from the ``total_absolute_main_key`` text input.

    The text is validated, converted and scaled by
    ``st.session_state["multiplier"]``. Values outside 50%-150% of the
    baseline total are replaced by the baseline and a warning is recorded.
    The baseline is ``actual_total_spends`` for the ``"Spend"`` goal and
    ``actual_total_sales`` otherwise. The result is written to
    ``modified_total_spends`` / ``modified_total_sales``. For ``"Spend"``, the
    percentage change is also pushed to the channels via ``total_input_change``.
    """
    raw_value = st.session_state["total_absolute_main_key"]
    if not is_valid_number_format(raw_value):
        return
    new_absolute = convert_to_float(raw_value) * st.session_state["multiplier"]

    data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
    panel = data[metrics_selected][panel_selected]
    is_spend_goal = optimization_goal == "Spend"
    old_absolute = (
        panel["actual_total_spends"] if is_spend_goal else panel["actual_total_sales"]
    )

    lower_bound, upper_bound = old_absolute * 0.5, old_absolute * 1.5
    if new_absolute < lower_bound or new_absolute > upper_bound:
        # Out of range: fall back to the baseline and tell the user.
        new_absolute = old_absolute
        st.session_state.message_display = {
            "type": "warning",
            "message": "Keep total spending within ±50% of the original value.",
            "icon": "⚠️",
        }

    target_field = "modified_total_spends" if is_spend_goal else "modified_total_sales"
    panel[target_field] = new_absolute
    st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data

    if is_spend_goal:
        per_change = ((new_absolute - old_absolute) / old_absolute) * 100
        total_input_change(per_change, metrics_selected, panel_selected)
# Function to update total_absolute_key change | |
def total_absolute_key_change(metrics_selected, panel_selected, optimization_goal):
    """Store a new panel total from the ``total_absolute_key`` text input.

    The text is first checked with ``is_valid_number_format``, which records a
    warning on failure, as ``total_absolute_main_key_change`` does. It is then
    converted and scaled by ``st.session_state["multiplier"]``. For the
    ``"Spend"`` goal, the value becomes the panel's ``modified_total_spends``
    and the percentage change against ``actual_total_spends`` is pushed to the
    channels via ``total_input_change``. Otherwise, it becomes
    ``modified_total_sales``. No ±50% range check is applied here.
    """
    raw_value = st.session_state["total_absolute_key"]
    # Malformed text would make convert_to_float raise; bail out instead.
    if not is_valid_number_format(raw_value):
        return
    new_absolute = convert_to_float(raw_value) * st.session_state["multiplier"]

    data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
    panel = data[metrics_selected][panel_selected]
    is_spend_goal = optimization_goal == "Spend"
    if is_spend_goal:
        panel["modified_total_spends"] = new_absolute
        old_absolute = panel["actual_total_spends"]
    else:
        panel["modified_total_sales"] = new_absolute
        old_absolute = panel["actual_total_sales"]

    st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data

    if is_spend_goal:
        per_change = ((new_absolute - old_absolute) / old_absolute) * 100
        total_input_change(per_change, metrics_selected, panel_selected)
# Function to update total_percentage_key change
def total_percentage_key_change(
    metrics_selected,
    panel_selected,
    absolute_value,
    optimization_goal,
):
    """Store a new panel total from the ``total_percentage_key`` percentage input.

    The new total is ``absolute_value * (1 + pct / 100)``. It is written to
    ``modified_total_spends`` for the ``"Spend"`` goal and to
    ``modified_total_sales`` otherwise. For ``"Spend"``, the percentage change
    against ``actual_total_spends`` is pushed to the channels via
    ``total_input_change``.
    """
    percentage = st.session_state["total_percentage_key"]
    new_absolute = absolute_value * (1 + percentage / 100)

    data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
    panel = data[metrics_selected][panel_selected]
    is_spend_goal = optimization_goal == "Spend"
    modified_field, actual_field = (
        ("modified_total_spends", "actual_total_spends")
        if is_spend_goal
        else ("modified_total_sales", "actual_total_sales")
    )
    panel[modified_field] = new_absolute
    old_absolute = panel[actual_field]

    st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data

    if is_spend_goal:
        per_change = ((new_absolute - old_absolute) / old_absolute) * 100
        total_input_change(per_change, metrics_selected, panel_selected)
# Function to update bound change | |
def bound_change(metrics_selected, panel_selected, channel_key, channel):
    """Store a channel's percentage bounds from its lower/upper inputs.

    Reads ``{channel_key}_lower_key`` and ``{channel_key}_upper_key``. If the
    lower bound exceeds the upper, the bounds fall back to ``[-10, 10]`` and a
    warning is recorded in ``st.session_state.message_display``.
    """
    session = st.session_state
    lower = session[f"{channel_key}_lower_key"]
    upper = session[f"{channel_key}_upper_key"]
    bounds_inverted = lower > upper

    new_bounds = [-10, 10] if bounds_inverted else [lower, upper]
    if bounds_inverted:
        session.message_display = {
            "type": "warning",
            "message": "Lower bound cannot be greater than Upper bound.",
            "icon": "⚠️",
        }

    planner = session["project_dct"]["scenario_planner"]
    scenario_metadata = planner["modified_metadata_file"]
    scenario_metadata[metrics_selected][panel_selected]["channels"][channel][
        "bounds"
    ] = new_bounds
    planner["modified_metadata_file"] = scenario_metadata
# Function to update freeze change | |
def freeze_change(metrics_selected, panel_selected, channel_key, channel, channel_list):
    """Sync a channel's freeze flag and bounds with its ``_allow_optimize_key`` checkbox.

    A truthy checkbox stores ``bounds=[0, 0]`` and ``freeze=True``. A falsy
    one stores ``bounds=[-10, 10]`` and ``freeze=False``. If fewer than two
    channels in ``channel_list`` are unchecked, a warning is recorded and
    nothing is stored.
    """
    unfrozen_channel_count = sum(
        not st.session_state[
            f"{metrics_selected}_{panel_selected}_{current}_allow_optimize_key"
        ]
        for current in channel_list
    )

    if unfrozen_channel_count < 2:
        # NOTE(review): if this runs as the checkbox's on_change callback, the
        # checkbox has already toggled, so returning here leaves the widget and
        # the stored bounds/freeze out of sync — confirm and consider reverting
        # the widget value.
        st.session_state.message_display = {
            "type": "warning",
            "message": "Please allow at least two channels to be optimized.",
            "icon": "⚠️",
        }
        return

    is_frozen = st.session_state[f"{channel_key}_allow_optimize_key"]
    new_bounds, new_freeze = ([0, 0], True) if is_frozen else ([-10, 10], False)

    planner = st.session_state["project_dct"]["scenario_planner"]
    scenario_metadata = planner["modified_metadata_file"]
    channel_entry = scenario_metadata[metrics_selected][panel_selected]["channels"][
        channel
    ]
    channel_entry["bounds"] = new_bounds
    channel_entry["freeze"] = new_freeze
    planner["modified_metadata_file"] = scenario_metadata
# Function to calculate y, ROI and MROI for given point | |
def get_point_parms(
    x_val,
    current_s_curve_params,
    current_channel_proportion,
    current_conversion_rate,
    channel_correction,
):
    """Return ``(roi, mroi, y)`` for one channel at raw spend ``x_val``.

    ``y`` is the summed S-curve output at ``x_val * current_channel_proportion``
    plus ``channel_correction``. ROI is ``y`` divided by the converted spend
    ``x_val * current_conversion_rate``. It is reported as 0 when that spend
    is 0, the same convention as the ROI curve in
    ``generate_response_curve_plots``, rather than dividing by zero. MROI is a
    forward difference over a 1e-3 nudge in converted spend.
    """

    def _response(raw_spend):
        # Curve output summed over the proportion split, plus the correction.
        return (
            sum(
                s_curve(
                    raw_spend * current_channel_proportion,
                    current_s_curve_params["power"],
                    current_s_curve_params["K"],
                    current_s_curve_params["b"],
                    current_s_curve_params["a"],
                    current_s_curve_params["x0"],
                )
            )
            + channel_correction
        )

    y_val = _response(x_val)

    # MROI: nudge the converted spend and map back to raw spend for the curve.
    nudge = 1e-3
    x1 = float(x_val * current_conversion_rate)
    y1 = float(y_val)
    x2 = x1 + nudge
    y2 = _response(x2 / current_conversion_rate)
    mroi_val = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0

    converted_spend = x_val * current_conversion_rate
    roi_val = y_val / converted_spend if converted_spend != 0 else 0
    return roi_val, mroi_val, y_val
# Function to find segment value | |
def find_segment_value(x, roi, mroi, roi_threshold=1, mroi_threshold=0.05):
    """Locate the "efficient" spend segment on a response curve.

    Returns:
        ``(start, end, left, right)``:
        - ``start`` and ``end`` are ``x[0]`` and ``x[-1]``.
        - ``left`` is the first ``x`` where ``roi > roi_threshold``.
        - ``right`` is the last ``x`` where both ``roi > roi_threshold`` and
          ``mroi > mroi_threshold``.
        Either of ``left``/``right`` falls back to ``x[0]`` when no point
        qualifies, and ``left`` is capped at ``right``.
    """
    first_x, last_x = x[0], x[-1]

    roi_ok = roi > roi_threshold
    above_roi = np.nonzero(roi_ok)[0]
    efficient = np.nonzero(roi_ok & (mroi > mroi_threshold))[0]

    left_value = x[above_roi[0]] if len(above_roi) else x[0]
    right_value = x[efficient[-1]] if len(efficient) else x[0]
    left_value = min(left_value, right_value)

    return first_x, last_x, left_value, right_value
# Helper: evaluate one channel's response curve at a spend given in actual units
def _channel_response(spend, params, proportion, correction):
    """Return ``sum(s_curve(spend * proportion, ...)) + correction``.

    ``params`` is the channel's s-curve parameter dict (keys ``power``, ``K``,
    ``b``, ``a``, ``x0``) and ``correction`` the channel's summed correction.
    """
    return (
        sum(
            s_curve(
                (spend * proportion),
                params["power"],
                params["K"],
                params["b"],
                params["a"],
                params["x0"],
            )
        )
        + correction
    )


# Helper: shade the yellow / green / red spend regions behind a response curve
def _add_region_shapes(fig, regions, y_top, multiplier):
    """Add the three background rectangles described by ``regions`` to ``fig``.

    Yellow spans start->left, green left->right and red right->end. Every
    rectangle runs from y=0 to ``y_top``, and all coordinates are divided by
    ``multiplier``.
    """
    bands = (
        (regions["start_value"], regions["left_value"], "rgba(255, 255, 0, 0.3)"),
        (regions["left_value"], regions["right_value"], "rgba(0, 255, 0, 0.3)"),
        (regions["right_value"], regions["end_value"], "rgba(255, 0, 0, 0.3)"),
    )
    for band_start, band_end, fill in bands:
        fig.add_shape(
            type="rect",
            x0=band_start / multiplier,
            y0=0,
            x1=band_end / multiplier,
            y1=y_top / multiplier,
            line=dict(width=0),
            fillcolor=fill,
            layer="below",
        )


# Function to generate response curves plots
def generate_response_curve_plots(
    channel_list,
    s_curve_params,
    channels_proportion,
    original_scenario_data,
    multiplier,
):
    """Build one response-curve figure per channel, annotated with ROI/MROI regions.

    For each channel, the curve is sampled at 100 points between 0 and 5x its
    actual total spend. Hover text shows ROI (response / converted spend) and
    MROI (forward difference with a 1e-3 step). The actual spend point is
    marked, and the background is shaded yellow / green / red from
    ``find_segment_value``, using ``st.session_state.roi_threshold`` and the
    module-level ``mroi_threshold``.

    Args:
        channel_list: Channels to plot, one figure each, in this order.
        s_curve_params: Mapping of channel to s-curve parameter dict.
        channels_proportion: Mapping of channel to the proportion multiplied
            into the spend.
        original_scenario_data: Scenario dict whose ``"channels"`` entry provides
            ``actual_total_spends``, ``conversion_rate`` and ``correction`` for
            each channel.
        multiplier: Divisor applied to every plotted x/y value (timeframe scaling).

    Returns:
        ``(figures, channel_roi_mroi, region_start_end)``:
        - the list of figures;
        - channel -> ``{"actual_roi", "actual_mroi"}``;
        - channel -> region bounds (``start_value``, ``end_value``,
          ``left_value``, ``right_value``).
    """
    figures, channel_roi_mroi, region_start_end = [], {}, {}
    # Finite-difference step for MROI, in converted-spend (x_plot) units
    nudge = 1e-3
    for channel in channel_list:
        channel_data = original_scenario_data["channels"][channel]
        spends_actual = channel_data["actual_total_spends"]
        conversion_rate = channel_data["conversion_rate"]
        channel_correction = sum(channel_data["correction"])
        params = s_curve_params[channel]
        proportion = channels_proportion[channel]

        # Sample the curve from 0 to 5x the actual spend; x_plot is in converted units
        x_actual = np.linspace(0, 5 * spends_actual, 100)
        x_plot = x_actual * conversion_rate
        y_plot = [
            _channel_response(x, params, proportion, channel_correction)
            for x in x_actual
        ]

        # ROI per sample (0 at zero spend to avoid dividing by zero)
        roi = [float(y) / float(x) if x != 0 else 0 for x, y in zip(x_plot, y_plot)]

        # MROI per sample via a forward difference of size `nudge`
        mroi = []
        for x_point, y_point in zip(x_plot, y_plot):
            x1 = float(x_point)
            y1 = float(y_point)
            x2 = x1 + nudge
            y2 = _channel_response(
                x2 / conversion_rate, params, proportion, channel_correction
            )
            mroi.append((float(y2) - y1) / (x2 - x1) if x2 != x1 else 0)

        # ROI, MROI and response at the actual spend point
        roi_actual, mroi_actual, y_actual = get_point_parms(
            spends_actual,
            params,
            proportion,
            conversion_rate,
            channel_correction,
        )

        fig = go.Figure()
        # S-curve line
        fig.add_trace(
            go.Scatter(
                x=np.array(x_plot) / multiplier,
                y=np.array(y_plot) / multiplier,
                mode="lines",
                name="Metrics",
                hoverinfo="text",
                text=[
                    f"Spends: {numerize(x / multiplier)}<br>{metrics_selected_formatted}: {numerize(y / multiplier)}<br>ROI: {r:.2f}<br>MROI: {m:.2f}"
                    for x, y, r, m in zip(x_plot, y_plot, roi, mroi)
                ],
            )
        )
        # Current (actual) spend point
        fig.add_trace(
            go.Scatter(
                x=[spends_actual * conversion_rate / multiplier],
                y=[y_actual / multiplier],
                mode="markers",
                marker=dict(color="cyan", size=10, symbol="circle"),
                name="Actual Spend",
                hoverinfo="text",
                text=[
                    f"Actual Spend: {numerize(spends_actual * conversion_rate / multiplier)}<br>{metrics_selected_formatted}: {numerize(y_actual / multiplier)}<br>ROI: {roi_actual:.2f}<br>MROI: {mroi_actual:.2f}"
                ],
                showlegend=True,
            )
        )

        # Region detection works on the curve normalised to [0, 1] on both axes
        roi_threshold = st.session_state.roi_threshold
        x, y = np.array(x_plot), np.array(y_plot)
        x_peak, y_peak = max(x), max(y)
        # Guard the normalisation: a zero peak (zero spend or zero response)
        # would otherwise yield NaNs and RuntimeWarnings.
        x_scaled = x / x_peak if x_peak != 0 else np.zeros_like(x, dtype=float)
        y_scaled = y / y_peak if y_peak != 0 else np.zeros_like(y, dtype=float)

        # Slope of the normalised curve between consecutive samples; 0 at the first point
        dx = np.diff(x_scaled)
        dy = np.diff(y_scaled)
        mroi_scaled = np.zeros_like(x_scaled)
        with np.errstate(divide="ignore", invalid="ignore"):
            mroi_scaled[1:] = np.where(dx != 0, dy / dx, 0)

        start_value, end_value, left_value, right_value = find_segment_value(
            x_plot, np.array(roi), mroi_scaled, roi_threshold, mroi_threshold
        )
        regions = {
            "start_value": start_value,
            "end_value": end_value,
            "left_value": left_value,
            "right_value": right_value,
        }
        region_start_end[channel] = regions

        y_max = max(y_plot) * 1.3  # 30% extra space above the max
        _add_region_shapes(fig, regions, y_max, multiplier)

        fig.update_layout(
            title=f"{name_formating(channel)}",
            showlegend=False,
            xaxis=dict(
                showgrid=True,
                showticklabels=True,
                tickformat=".2s",
                gridcolor="lightgrey",
                gridwidth=0.5,
                griddash="dot",
            ),
            yaxis=dict(
                showgrid=True,
                showticklabels=True,
                tickformat=".2s",
                gridcolor="lightgrey",
                gridwidth=0.5,
                griddash="dot",
            ),
            template="plotly_white",
            margin=dict(l=20, r=20, t=30, b=20),
            # NOTE(review): evaluates as (100 * (n + 3)) // 4, not 100 * ceil(n / 4);
            # confirm which was intended.
            height=100 * (len(channel_list) + 4 - 1) // 4,
        )
        figures.append(fig)

        channel_roi_mroi[channel] = {
            "actual_roi": roi_actual,
            "actual_mroi": mroi_actual,
        }
    return figures, channel_roi_mroi, region_start_end
# Function to add modified spends/metrics point on plot
def modified_metrics_point(
    fig,
    modified_spends,
    s_curve_params,
    channels_proportion,
    conversion_rate,
    channel_correction,
):
    """Overlay the modified (optimized) spend point on a response-curve figure.

    ROI, MROI and the response value come from ``get_point_parms``. The marker
    position is scaled by ``st.session_state["multiplier"]``.

    Returns:
        ``(roi_modified, mroi_modified, fig)``, where ``fig`` is the same
        object passed in, with one trace appended.
    """
    roi_modified, mroi_modified, y_modified = get_point_parms(
        modified_spends,
        s_curve_params,
        channels_proportion,
        conversion_rate,
        channel_correction,
    )

    scale = st.session_state["multiplier"]
    point_x = modified_spends * conversion_rate / scale
    point_y = y_modified / scale
    hover_text = (
        f"Modified Spend: {numerize(point_x)}<br>"
        f"{metrics_selected_formatted}: {numerize(point_y)}<br>"
        f"ROI: {roi_modified:.2f}<br>MROI: {mroi_modified:.2f}"
    )

    fig.add_trace(
        go.Scatter(
            x=[point_x],
            y=[point_y],
            mode="markers",
            marker=dict(color="blueviolet", size=10, symbol="circle"),
            name="Optimized Spend",
            hoverinfo="text",
            text=[hover_text],
            showlegend=True,
        )
    )
    return roi_modified, mroi_modified, fig
# Function to update bound type change
def bound_type_change():
    """Persist the bound-type toggle and reset channel bounds when it is turned off.

    Reads the new value from ``st.session_state["bound_type_key"]`` and stores it
    in the modified scenario metadata for the module-level ``metrics_selected`` /
    ``panel_selected``. When the value is falsy (default bounds), every
    channel's bounds are reset to ``[-10, 10]``.
    """
    bound_type = st.session_state["bound_type_key"]
    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    data = scenario_planner["modified_metadata_file"]

    panel_data = data[metrics_selected][panel_selected]
    panel_data["bound_type"] = bound_type
    channels = panel_data["channels"]

    if not bound_type:
        for channel_info in channels.values():
            # A fresh list per channel, so channels never share a bounds object
            channel_info["bounds"] = [-10, 10]

    scenario_planner["modified_metadata_file"] = data
# Function to format the numbers with decimal
def format_value(input_value):
    """Return a display string for ``input_value``.

    Magnitudes below 1 are shown with four decimal places. Everything else is
    abbreviated with ``numerize`` using one decimal place.
    """
    if abs(input_value) < 1:
        return f"{input_value:.4f}"
    return f"{numerize(input_value, 1)}"
# Function to round numbers for display
def round_value(input_value):
    """Round ``input_value`` to 4 decimals if its magnitude is below 1, else to 1 decimal."""
    digits = 4 if abs(input_value) < 1 else 1
    return round(input_value, digits)
# Function to generate ROI and MROI plots for all channels
def roi_mori_plot(channel_roi_mroi):
    """Build compact actual-vs-optimized bar charts of ROI and MROI per channel.

    Args:
        channel_roi_mroi: Mapping of channel to a dict holding ``actual_roi``,
            ``optimized_roi``, ``actual_mroi`` and ``optimized_mroi``.

    Returns:
        Dict mapping each channel to ``{"fig_roi": Figure, "fig_mroi": Figure}``.
    """

    def _comparison_bars(label, actual, optimized):
        # One 110x110 figure: a cyan "Actual" bar and a blueviolet "Optimized" bar.
        figure = go.Figure()
        for prefix, value, color in (
            ("Actual", actual, "cyan"),
            ("Optimized", optimized, "blueviolet"),
        ):
            trace_name = f"{prefix} {label}"
            figure.add_trace(
                go.Bar(
                    x=[trace_name],
                    y=[value],
                    name=trace_name,
                    marker_color=color,
                    width=1,
                    text=[format_value(value)],
                    textposition="auto",
                    textfont=dict(color="black", size=14),
                )
            )
        figure.update_layout(
            annotations=[
                dict(
                    x=0.5,
                    y=1.3,
                    xref="paper",
                    yref="paper",
                    text=label,
                    showarrow=False,
                    font=dict(size=14),
                )
            ],
            barmode="group",
            bargap=0,
            showlegend=False,
            width=110,
            height=110,
            xaxis=dict(
                showticklabels=True,
                showgrid=False,
                tickangle=0,
                ticktext=["Actual", "Optimized"],
                tickvals=[f"Actual {label}", f"Optimized {label}"],
            ),
            yaxis=dict(showticklabels=False, showgrid=False),
            margin=dict(t=20, b=20, r=0, l=0),
        )
        return figure

    plots = {}
    for channel, values in channel_roi_mroi.items():
        actual_roi = values["actual_roi"]
        optimized_roi = values["optimized_roi"]
        actual_mroi = values["actual_mroi"]
        optimized_mroi = values["optimized_mroi"]
        plots[channel] = {
            "fig_roi": _comparison_bars("ROI", actual_roi, optimized_roi),
            "fig_mroi": _comparison_bars("MROI", actual_mroi, optimized_mroi),
        }
    return plots
# Function to save the current scenario with the mentioned name
def save_scenario(
    scenario_dict,
    metrics_selected,
    panel_selected,
    optimization_goal,
    channel_roi_mroi,
    timeframe,
    multiplier,
):
    """Validate the scenario name and persist the scenario to the project DB.

    The name is read from ``st.session_state["scenario_name"]``; whitespace is
    stripped and written back. The outcome (warning or success) is reported
    through ``st.session_state.message_display``. On success the project is
    saved with ``update_db``, a toast is shown and the name input is cleared.

    On success, ``scenario_dict`` is augmented in place with the selection
    details (panel, metric, optimization goal, ROI/MROI, timeframe,
    multiplier). A deep copy of it is what gets stored, so later edits to the
    caller's dict cannot silently change the saved scenario. A rejected
    duplicate name leaves ``scenario_dict`` untouched.
    """
    name = st.session_state["scenario_name"]
    if name is not None:
        # Remove extra space at start and end
        name = name.strip()
        st.session_state["scenario_name"] = name

    if name is None or name == "":
        st.session_state.message_display = {
            "type": "warning",
            "message": "Please provide a name to save the scenario.",
            "icon": "⚠️",
        }
        return

    # Length limits plus letters, digits and underscore only
    if not (
        word_length_limit_lower <= len(name) <= word_length_limit_upper
        and re.fullmatch(r"[A-Za-z0-9_]*", name)
    ):
        st.session_state.message_display = {
            "type": "warning",
            "message": f"Please provide a valid scenario name ({word_length_limit_lower}-{word_length_limit_upper} characters, only A-Z, a-z, 0-9, and _).",
            "icon": "⚠️",
        }
        return

    if not scenario_dict:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Nothing to save. The scenario data is empty.",
            "icon": "⚠️",
        }
        return

    saved_scenarios_dict = st.session_state["project_dct"]["saved_scenarios"][
        "saved_scenarios_dict"
    ]

    # Reject duplicates before touching the caller's data
    if name in saved_scenarios_dict:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Name already exists. Please change the name or delete the existing scenario from the Saved Scenario page.",
            "icon": "⚠️",
        }
        return

    # Add additional scenario details
    scenario_dict["panel_selected"] = panel_selected
    scenario_dict["metrics_selected"] = metrics_selected
    scenario_dict["optimization"] = optimization_goal
    scenario_dict["channel_roi_mroi"] = channel_roi_mroi
    scenario_dict["timeframe"] = timeframe
    scenario_dict["multiplier"] = multiplier

    # Store an independent snapshot so the saved scenario does not alias live data
    saved_scenarios_dict[name] = copy.deepcopy(scenario_dict)
    st.session_state["project_dct"]["saved_scenarios"][
        "saved_scenarios_dict"
    ] = saved_scenarios_dict

    update_db(
        prj_id=st.session_state["project_number"],
        page_nam="Scenario Planner",
        file_nam="project_dct",
        pkl_obj=pickle.dumps(st.session_state["project_dct"]),
        schema=schema,
    )

    st.session_state.message_display = {
        "type": "success",
        "message": f"Scenario '{name}' has been successfully saved!",
        "icon": "💾",
    }
    st.toast(
        f"Scenario '{name}' has been successfully saved!",
        icon="💾",
    )

    # Clear the scenario name input
    st.session_state["scenario_name"] = ""
# Function to calculate the RGBA color code based on the spends value and region boundaries
def calculate_rgba(spends_value, region_start_end):
    """Map a spend value to the RGBA shade of the region it falls in.

    Regions come from ``region_start_end``:
      * yellow ``[start_value, left_value]``, strongest towards ``start_value``;
      * green ``(left_value, right_value]``, strongest towards ``left_value``;
      * red ``(right_value, end_value]``, strongest towards ``end_value``.

    Alpha ramps linearly between 0.1 and 0.4 inside each region. A zero-width
    region uses the strongest shade (0.4) instead of dividing by zero.

    Returns:
        An ``(r, g, b, alpha)`` tuple, or ``None`` when ``spends_value`` lies
        outside ``[start_value, end_value]``.
    """
    start_value = region_start_end["start_value"]
    end_value = region_start_end["end_value"]
    left_value = region_start_end["left_value"]
    right_value = region_start_end["right_value"]

    def calculate_alpha(position, start, end, min_alpha=0.1, max_alpha=0.4):
        # Linear ramp: min_alpha at `start`, max_alpha at `end`.
        if end == start:
            # Degenerate region (e.g. start_value == left_value): avoid 0/0.
            return max_alpha
        return min_alpha + (max_alpha - min_alpha) * (position - start) / (end - start)

    if start_value <= spends_value <= left_value:
        # Yellow (255, 255, 0): faint towards left_value, darker towards start_value
        alpha = calculate_alpha(spends_value, left_value, start_value)
        return (255, 255, 0, alpha)
    elif left_value < spends_value <= right_value:
        # Green (0, 128, 0): faint towards right_value, darker towards left_value
        alpha = calculate_alpha(spends_value, right_value, left_value)
        return (0, 128, 0, alpha)
    elif right_value < spends_value <= end_value:
        # Red (255, 0, 0): faint towards right_value, darker towards end_value
        alpha = calculate_alpha(spends_value, right_value, end_value)
        return (255, 0, 0, alpha)
    return None
# Function to format and display the channel name with a color and background color
def display_channel_name_with_background_color(
    channel_name, background_color=(0, 128, 0, 0.1)
):
    """Return an HTML snippet that shows the formatted channel name on a coloured chip.

    ``background_color`` must unpack into exactly four values ``(r, g, b, a)``.
    They are rendered as a CSS ``rgba(...)`` background.
    """
    label = name_formating(channel_name)
    red, green, blue, alpha = background_color
    return f"""
<div style="
background-color: rgba({red}, {green}, {blue}, {alpha});
padding: 10px;
display: inline-block;
border-radius: 5px;">
<strong>{label}</strong>
</div>
"""
# Function to check optimization success
def check_optimization_success(
    channel_list,
    input_channels_spends,
    output_channels_spends,
    bounds_dict,
    optimization_goal,
    modified_total_metrics,
    actual_total_metrics,
    modified_total_spends,
    actual_total_spends,
    original_total_spends,
    optimization_success,
):
    """Validate an optimizer result against bounds, the fixed target and sanity limits.

    Checks run in this order, and the first failure is returned:
      1. Each channel's optimized spend (absolute value) lies within its
         percentage bounds, widened by 1 percentage point on each side. The
         lower limit is floored at 0 and the upper limit at 10.
      2. The held-fixed total moved by at most 1%. That total is spends when
         ``optimization_goal == "Spend"``, otherwise the response metric. A
         zero baseline counts as unchanged only if the new value is also zero.
      3. The optimized total spend is within ±50% of ``original_total_spends``.
      4. The optimizer reported convergence (``optimization_success``).

    Returns:
        ``(ok, message, icon)``.
    """

    def relative_change(new, old):
        # |new - old| / |old|; a zero baseline gives 0 if unchanged, else inf,
        # instead of raising ZeroDivisionError.
        if old == 0:
            return 0.0 if new == old else float("inf")
        return abs((new - old) / old)

    for channel in channel_list:
        input_channel_spends = input_channels_spends[channel]
        output_channel_spends = abs(output_channels_spends[channel])
        lower_percent, upper_percent = bounds_dict[channel][0], bounds_dict[channel][1]

        # 1% tolerance on each side
        lower_allowed_value = max(
            (input_channel_spends * (100 + lower_percent - 1) / 100), 0
        )
        upper_allowed_value = max(
            (input_channel_spends * (100 + upper_percent + 1) / 100), 10
        )

        if (
            output_channel_spends > upper_allowed_value
            or output_channel_spends < lower_allowed_value
        ):
            return (
                False,
                "Optimization failed: strict bounds. Use flexible bounds.",
                "❌",
            )

    # The quantity held fixed by the goal must not drift by more than 1%
    if optimization_goal == "Spend":
        if relative_change(modified_total_spends, actual_total_spends) > 0.01:
            return (
                False,
                "Optimization failed: input and optimized spends differ. Use flexible bounds.",
                "❌",
            )
    elif relative_change(modified_total_metrics, actual_total_metrics) > 0.01:
        return (
            False,
            "Optimization failed: input and optimized metrics differ. Use flexible bounds.",
            "❌",
        )

    # New spends must stay within ±50% of the original total
    lower_limit = original_total_spends * 0.5
    upper_limit = original_total_spends * 1.5
    if modified_total_spends < lower_limit or modified_total_spends > upper_limit:
        return (
            False,
            "New spends optimized are outside the allowed range of ±50%.",
            "❌",
        )

    if not optimization_success:
        return False, "Optimization failed to converge.", "❌"

    return True, "Optimization successful.", "💸"
# Function to check if the optimization target is achievable within the given bounds
def check_target_achievability(
    optimize_allow,
    optimization_goal,
    lower_achievable_target,
    upper_achievable_target,
    total_absolute_target,
):
    """Gate optimization on whether the target lies inside the achievable range.

    If ``total_absolute_target`` is below ``lower_achievable_target`` or above
    ``upper_achievable_target``, an error describing the achievable range
    (scaled by ``st.session_state.multiplier``) is stored in
    ``st.session_state.message_display`` and ``False`` is returned.

    Otherwise, if the message currently shown starts with "Achievable", it is
    cleared and ``True`` is returned. In every remaining case,
    ``optimize_allow`` is returned unchanged.
    """
    multiplier = st.session_state.multiplier
    given_input = "spends" if optimization_goal != "Spend" else "response metric"
    low_text = numerize(lower_achievable_target / multiplier)
    high_text = numerize(upper_achievable_target / multiplier)
    achievable_message = (
        f"Achievable {optimization_goal} with the given {given_input} and bounds ranges from "
        f"{low_text} to {high_text}"
    )

    target_out_of_range = (
        lower_achievable_target > total_absolute_target
        or upper_achievable_target < total_absolute_target
    )
    if target_out_of_range:
        st.session_state.message_display = {
            "type": "error",
            "message": achievable_message,
            "icon": "⚠️",
        }
        return False

    current_message = st.session_state.message_display["message"]
    if current_message is not None and str(current_message).startswith("Achievable"):
        # The target is reachable again; drop the stale range warning
        st.session_state.message_display = {
            "type": "success",
            "message": None,
            "icon": "",
        }
        return True

    return optimize_allow
# Function to display a message with the appropriate type and icon
def display_message():
    """Render ``st.session_state.message_display`` (when it holds a message) and log it.

    The ``type`` field selects the widget: "success", "warning" or "error".
    Any other value falls back to ``st.info``. Warnings and errors are logged
    at their own level; everything else is logged as "info".
    """
    details = st.session_state.message_display
    message_type = details["type"]
    message = details["message"]
    icon = details["icon"]

    if message is None:
        return

    renderers = {
        "success": (st.success, "info"),
        "warning": (st.warning, "warning"),
        "error": (st.error, "error"),
    }
    show, log_level = renderers.get(message_type, (st.info, "info"))
    show(message, icon=icon)
    log_message(log_level, message, "Scenario Planner")
# Function to change bounds for all channels
def all_bound_change(channel_list, apply_all=False):
    """Apply the "all channels" lower/upper bound inputs to the modified scenario.

    Reads ``all_lower_key`` and ``all_upper_key`` from session state. If lower
    is strictly less than upper:
      * the pair is stored as the panel-level ``bounds`` for the module-level
        ``metrics_selected`` / ``panel_selected``;
      * when ``apply_all`` is true, the pair is also copied to every channel in
        ``channel_list`` whose ``freeze`` flag is falsy.

    Otherwise a warning is stored in ``st.session_state.message_display`` and
    nothing is changed.
    """
    all_lower_bound = st.session_state["all_lower_key"]
    all_upper_bound = st.session_state["all_upper_key"]

    # Equal bounds are rejected as well as inverted ones, so the warning must
    # say "must be less than" rather than "cannot be greater than".
    if not all_lower_bound < all_upper_bound:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Lower bound must be less than Upper bound.",
            "icon": "⚠️",
        }
        return

    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    data = scenario_planner["modified_metadata_file"]
    panel_data = data[metrics_selected][panel_selected]

    if apply_all:
        for channel in channel_list:
            channel_info = panel_data["channels"][channel]
            # Frozen channels keep their own bounds
            if not channel_info["freeze"]:
                channel_info["bounds"] = [all_lower_bound, all_upper_bound]

    # Panel-level holder for the "all channels" bounds
    panel_data["bounds"] = [all_lower_bound, all_upper_bound]

    scenario_planner["modified_metadata_file"] = data
try: | |
# Page Title | |
st.title("Scenario Planner") | |
# Retrieve the list of all metric names from the specified directory | |
metrics_list = get_metrics_names() | |
# Check if there are any metrics available in the metrics list | |
if not metrics_list: | |
# Display a warning message to the user if no metrics are found | |
st.warning( | |
"Please tune at least one model to generate response curves data.", | |
icon="⚠️", | |
) | |
# Log message | |
log_message( | |
"warning", | |
"Please tune at least one model to generate response curves data.", | |
"Scenario Planner", | |
) | |
st.stop() | |
# Widget columns | |
metric_col, panel_col, timeframe_col, save_progress_col = st.columns(4) | |
# Metrics Selection | |
metrics_selected = metric_col.selectbox( | |
"Response Metrics", | |
sorted(metrics_list), | |
format_func=name_formating, | |
key="response_metrics_selectbox_sp", | |
index=0, | |
) | |
metrics_selected_formatted = name_formating(metrics_selected) | |
# Retrieve the list of all panel names for specified Metrics | |
panel_list = get_panels_names(metrics_selected) | |
# Panel Selection | |
panel_selected = panel_col.selectbox( | |
"Panel", | |
sorted(panel_list), | |
format_func=name_formating, | |
key="panel_selected_selectbox_sp", | |
index=0, | |
) | |
panel_selected_formatted = name_formating(panel_selected) | |
# Timeframe Selection | |
timeframe_selected = timeframe_col.selectbox( | |
"Timeframe", | |
["Input Data Range", "Yearly", "Quarterly", "Monthly"], | |
key="timeframe_selected_selectbox_sp", | |
index=0, | |
) | |
# Check if the RCS metadata file does not exist | |
if ( | |
st.session_state["project_dct"]["response_curves"]["original_metadata_file"] | |
is None | |
or st.session_state["project_dct"]["response_curves"]["modified_metadata_file"] | |
is None | |
): | |
# RCS metadata file does not exist. Generating new RCS data | |
generate_rcs_data() | |
# Log message | |
log_message( | |
"info", | |
"RCS metadata file does not exist. Generating new RCS data.", | |
"Scenario Planner", | |
) | |
# Load rcs metadata files if they exist | |
original_rcs_data, modified_rcs_data = load_rcs_metadata_files() | |
# Check if the scenario metadata file does not exist | |
if ( | |
st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"] | |
is None | |
or st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] | |
is None | |
): | |
# Scenario file does not exist. Generating new senario file data | |
generate_scenario_data() | |
# Load scenario metadata files if they exist | |
original_data, modified_data = load_scenario_metadata_files() | |
try: | |
# Data date range | |
date_range = pd.to_datetime( | |
list(original_data[metrics_selected][panel_selected]["channels"].values())[ | |
0 | |
]["dates"] | |
) | |
# Calculate the number of days between max and min dates | |
date_diff = pd.Series(date_range).diff() | |
day_data = int( | |
(date_range.max() - date_range.min()).days | |
+ (6 if date_diff.value_counts().idxmax() == pd.Timedelta(weeks=1) else 0) | |
) | |
# Set the multiplier based on the selected timeframe | |
if timeframe_selected == "Input Data Range": | |
st.session_state["multiplier"] = 1 | |
elif timeframe_selected == "Yearly": | |
st.session_state["multiplier"] = day_data / 365 | |
elif timeframe_selected == "Quarterly": | |
st.session_state["multiplier"] = day_data / 90 | |
elif timeframe_selected == "Monthly": | |
st.session_state["multiplier"] = day_data / 30 | |
except: | |
st.session_state["multiplier"] = 1 | |
# Extract original scenario data for the selected metric and panel | |
original_scenario_data = original_data[metrics_selected][panel_selected] | |
# Extract modified scenario data for the same metric and panel | |
modified_scenario_data = modified_data[metrics_selected][panel_selected] | |
# Display Actual Vs Optimized | |
st.divider() | |
( | |
actual_spends_col, | |
actual_metrics_col, | |
actual_CPA_col, | |
base_col, | |
optimized_spends_col, | |
optimized_metrics_col, | |
optimized_CPA_col, | |
) = st.columns([1, 1, 1, 1, 1.5, 1.5, 1.5]) | |
# Base Contribution | |
base_contribution = ( | |
sum(original_scenario_data["constant"]) / st.session_state["multiplier"] | |
) | |
# Display Base Metric | |
base_col.metric( | |
f"Base {metrics_selected_formatted}", | |
numerize(base_contribution), | |
) | |
# Extracting and formatting values | |
actual_spends = numerize( | |
original_scenario_data["actual_total_spends"] / st.session_state["multiplier"] | |
) | |
actual_metric_value = numerize( | |
original_scenario_data["actual_total_sales"] / st.session_state["multiplier"] | |
) | |
optimized_spends = numerize( | |
modified_scenario_data["modified_total_spends"] / st.session_state["multiplier"] | |
) | |
optimized_metric_value = numerize( | |
modified_scenario_data["modified_total_sales"] / st.session_state["multiplier"] | |
) | |
# Calculate the deltas (differences) for spends and metrics | |
spends_delta_value = ( | |
modified_scenario_data["modified_total_spends"] | |
- original_scenario_data["actual_total_spends"] | |
) / st.session_state["multiplier"] | |
metrics_delta_value = ( | |
modified_scenario_data["modified_total_sales"] | |
- original_scenario_data["actual_total_sales"] | |
) / st.session_state["multiplier"] | |
# Calculate the percentage changes for spends and metrics | |
spends_percentage_change = ( | |
spends_delta_value | |
/ ( | |
original_scenario_data["actual_total_spends"] | |
/ st.session_state["multiplier"] | |
) | |
) * 100 | |
metrics_percentage_change_media = ( | |
metrics_delta_value | |
/ ( | |
( | |
original_scenario_data["actual_total_sales"] | |
/ st.session_state["multiplier"] | |
) | |
- base_contribution | |
) | |
) * 100 | |
metrics_percentage_change_all = ( | |
metrics_delta_value | |
/ ( | |
original_scenario_data["actual_total_sales"] | |
/ st.session_state["multiplier"] | |
) | |
) * 100 | |
# Format the percentage change for display | |
spends_percentage_display = ( | |
f"({round(spends_percentage_change, 1)}%)" | |
if abs(spends_percentage_change) >= 0.1 | |
else "(0%)" | |
) | |
metrics_percentage_display_media = ( | |
f"({round(metrics_percentage_change_media, 1)}%)" | |
if abs(metrics_percentage_change_media) >= 0.1 | |
else "(0%)" | |
) | |
metrics_percentage_display_all = ( | |
f"({round(metrics_percentage_change_all, 1)}%)" | |
if abs(metrics_percentage_change_all) >= 0.1 | |
else "(0%)" | |
) | |
# Check if the delta for spends is less than 0.1% in absolute terms | |
if abs(spends_delta_value) < 0.001 * original_scenario_data["actual_total_spends"]: | |
spends_delta = "0" | |
else: | |
spends_delta = numerize(spends_delta_value) | |
# Check if the delta for metrics is less than 0.1% in absolute terms | |
if abs(metrics_delta_value) < 0.001 * original_scenario_data["actual_total_sales"]: | |
metrics_delta = "0" | |
else: | |
metrics_delta = numerize(metrics_delta_value) | |
# Display current and optimized CPA | |
actual_CPA = ( | |
original_scenario_data["actual_total_spends"] | |
/ original_scenario_data["actual_total_sales"] | |
) | |
optimized_CPA = ( | |
modified_scenario_data["modified_total_spends"] | |
/ modified_scenario_data["modified_total_sales"] | |
) | |
CPA_delta_value = optimized_CPA - actual_CPA | |
# Calculate the percentage change for CPA | |
CPA_percentage_change = ( | |
((CPA_delta_value / actual_CPA) * 100) if actual_CPA != 0 else 0 | |
) | |
CPA_percentage_display = ( | |
f"({round(CPA_percentage_change, 1)}%)" | |
if abs(CPA_percentage_change) >= 0.1 | |
else "(0%)" | |
) | |
# Check if the CPA delta is less than 0.1% in absolute terms | |
if abs(CPA_delta_value) < 0.001 * actual_CPA: | |
CPA_delta = "0" | |
else: | |
CPA_delta = round_value(CPA_delta_value) | |
# Display the metrics with percentage changes | |
# Headline scenario metrics. CPA values >= 1000 are abbreviated with
# `numerize`; smaller ones go through `round_value` instead.
actual_CPA_col.metric(
    "Actual CPA",
    (numerize(actual_CPA) if actual_CPA >= 1000 else round_value(actual_CPA)),
)
optimized_spends_col.metric(
    "Optimized Spend",
    f"{optimized_spends} {spends_percentage_display}",
    delta=spends_delta,
)
# The value carries the total-metric percentage; the delta line carries the
# media-only percentage (see the note rendered below).
optimized_metrics_col.metric(
    f"Optimized {metrics_selected_formatted}",
    f"{optimized_metric_value} {metrics_percentage_display_all}",
    delta=f"{metrics_delta} {metrics_percentage_display_media}",
)
# delta_color="inverse": for CPA a decrease is the favourable direction.
optimized_CPA_col.metric(
    "Optimized CPA",
    (
        f"{numerize(optimized_CPA) if optimized_CPA >= 1000 else round_value(optimized_CPA)} {CPA_percentage_display}"
    ),
    delta=CPA_delta,
    delta_color="inverse",
)
# Displaying metrics in the columns
actual_spends_col.metric("Actual Spend", actual_spends)
actual_metrics_col.metric(
    f"Actual {metrics_selected_formatted}",
    actual_metric_value,
)
# Choose the colour of the media-only percentage shown in the note below.
# NOTE(review): the sign is read from `metrics_percentage_display_all` (first
# character skipped, presumably an opening bracket -- confirm the format), not
# from the media string that gets coloured. With a fixed base contribution both
# changes should normally share a sign; confirm this is intended.
if str(metrics_percentage_display_all[1:]).startswith("-"):
    # Negative change: red
    metrics_percentage_display_media_str = f'<span style="color:rgb(255, 43, 43)">red <strong>{metrics_percentage_display_media}</strong></span>'
else:
    # Otherwise: green
    metrics_percentage_display_media_str = f'<span style="color:rgb(9, 171, 59)">green <strong>{metrics_percentage_display_media}</strong></span>'
# Explain the two percentages; HTML is required for the coloured span.
st.markdown(
    f"**Note:** The percentage change for the response metric in {metrics_percentage_display_media_str} reflects the change based on the media-driven portion only, excluding the fixed base contribution and the percentage in black **{metrics_percentage_display_all}** represents the change based on the total response metric, including the base contribution. For spends, the percentage change **{spends_percentage_display}** is based on the total actual spends (base spends are always zero).",
    unsafe_allow_html=True,
)
# Divider
st.divider()
# ROI threshold: media-driven sales (actual total sales minus the summed base
# "constant") earned per unit of actual total spend.
st.session_state["roi_threshold"] = (
    original_scenario_data["actual_total_sales"]
    - sum(original_scenario_data["constant"])
) / original_scenario_data["actual_total_spends"]
# Channel names ordered from the largest to the smallest actual spend, where a
# channel's spend is its raw actual spend multiplied by its conversion rate.
channel_list = sorted(
    original_scenario_data["channels"],
    key=lambda name: original_scenario_data["channels"][name]["actual_total_spends"]
    * original_scenario_data["channels"][name]["conversion_rate"],
    reverse=True,
)
# Control row: goal selector, message/spinner area, Optimize/Reset buttons and
# the bounds controls.
(
    optimization_goal_col,
    message_display_col,
    button_col,
    bounds_col,
) = st.columns([3, 6, 3, 3])
# Display spinner or message
with message_display_col:
    st.write("###")
    # Shared slot used by the "Saving Progress" and "Optimizing" spinners
    spinner_placeholder = st.empty()
# Save Progress
# NOTE(review): `save_progress_col` is not created in this part of the page;
# it is assumed to come from an earlier st.columns call -- confirm.
with save_progress_col:
    st.write("####")  # Padding
    save_progress_placeholder = st.empty()
# Save page progress. The click is handled inline on the rerun where the button
# returns True, so the spinner is shown while update_db runs.
with spinner_placeholder, st.spinner("Saving Progress ..."):
    if save_progress_placeholder.button("Save Progress", use_container_width=True):
        # Update DB with the pickled project dictionary for this page
        update_db(
            prj_id=st.session_state["project_number"],
            page_nam="Scenario Planner",
            file_nam="project_dct",
            pkl_obj=pickle.dumps(st.session_state["project_dct"]),
            schema=schema,
        )
        # Store the message details in session state
        with message_display_col:
            st.session_state.message_display = {
                "type": "success",
                "message": "Progress saved successfully!",
                "icon": "💾",
            }
            st.toast("Progress saved successfully!", icon="💾")
# Row for the total target: absolute text, slider, percentage and all-bounds inputs
absolute_text_col, absolute_slider_col, percentage_number_col, all_bounds_col = (
    st.columns([2, 4, 2, 2])
)
# Dropdown selecting which total is held fixed: total spend or the selected
# response metric.
optimization_goal = optimization_goal_col.selectbox(
    "Fix", ["Spend", metrics_selected_formatted]
)
# Button columns with padding for alignment
with button_col:
    st.write("##")  # Padding
    # optimize_button_col is filled later on the page, once all optimizer
    # inputs have been collected.
    optimize_button_col, reset_button_col = st.columns(2)
    reset_button_col.button(
        "Reset",
        use_container_width=True,
        on_click=reset_scenario,
        args=(metrics_selected, panel_selected),
    )
# Absolute value display.
# absolute_value keeps the ACTUAL total of the fixed quantity; it anchors the
# slider range and is passed to the percentage callback. The text box itself is
# seeded with the current MODIFIED total, scaled by the display multiplier.
if optimization_goal == "Spend":
    absolute_value = modified_scenario_data["actual_total_spends"]
    st.session_state.total_absolute_main_key = numerize(
        modified_scenario_data["modified_total_spends"]
        / st.session_state["multiplier"]
    )
else:
    absolute_value = modified_scenario_data["actual_total_sales"]
    st.session_state.total_absolute_main_key = numerize(
        modified_scenario_data["modified_total_sales"]
        / st.session_state["multiplier"]
    )
# Free-text absolute target; edits are applied by total_absolute_main_key_change
total_absolute = absolute_text_col.text_input(
    "Absolute",
    key="total_absolute_main_key",
    on_change=total_absolute_main_key_change,
    args=(
        metrics_selected,
        panel_selected,
        optimization_goal,
    ),
)
# Slider choices: 50 evenly spaced totals spanning 50%-150% of the actual total,
# plus the current modified total so the slider can land on it exactly.
slider_options = sorted(
    [
        *np.linspace(int(0.5 * absolute_value), int(1.5 * absolute_value), 50),
        (
            modified_scenario_data["modified_total_spends"]
            if optimization_goal == "Spend"
            else modified_scenario_data["modified_total_sales"]
        ),
    ]
)
# Human-readable labels, scaled by the display multiplier
numerized_slider_options = [
    numerize(option / st.session_state["multiplier"]) for option in slider_options
]
# Pre-select the label that corresponds to the current modified total
st.session_state["total_absolute_key"] = numerize(
    (
        modified_scenario_data["modified_total_spends"]
        if optimization_goal == "Spend"
        else modified_scenario_data["modified_total_sales"]
    )
    / st.session_state["multiplier"]
)
# Slider for adjusting the absolute target; changes go through total_absolute_key_change
slider_value = absolute_slider_col.select_slider(
    "Absolute",
    numerized_slider_options,
    key="total_absolute_key",
    on_change=total_absolute_key_change,
    args=(metrics_selected, panel_selected, optimization_goal),
)
# Number input for percentage value: seed it with the current change of the
# fixed quantity relative to its actual total, rounded to a whole percent.
if optimization_goal == "Spend":
    st.session_state.total_percentage_key = int(
        round(
            (
                (
                    modified_scenario_data["modified_total_spends"]
                    - modified_scenario_data["actual_total_spends"]
                )
                / modified_scenario_data["actual_total_spends"]
            )
            * 100,
            0,
        )
    )
else:
    st.session_state.total_percentage_key = int(
        round(
            (
                (
                    modified_scenario_data["modified_total_sales"]
                    - modified_scenario_data["actual_total_sales"]
                )
                / modified_scenario_data["actual_total_sales"]
            )
            * 100,
            0,
        )
    )
# Limited to +/-50%, matching the slider's 50%-150% range.
# NOTE(review): the seeded value is not clamped here; this assumes the change
# callbacks keep the modified total within +/-50% -- confirm.
percentage_target = percentage_number_col.number_input(
    "Percentage",
    min_value=-50,
    max_value=50,
    key="total_percentage_key",
    on_change=total_percentage_key_change,
    args=(
        metrics_selected,
        panel_selected,
        absolute_value,
        optimization_goal,
    ),
)
# Toggle input for bound type, seeded from the stored scenario setting
st.session_state["bound_type_key"] = modified_scenario_data["bound_type"]
with bounds_col:
    st.write("##")  # Padding
    # Columns for custom bounds toggle and apply all bounds button
    allow_custom_bounds_col, apply_all_bounds_col = st.columns(2)
    # Toggle for enabling/disabling custom bounds
    bound_type = allow_custom_bounds_col.toggle(
        "Bounds",
        on_change=bound_type_change,
        key="bound_type_key",
    )
    # Button to apply the all-channel bounds (second callback arg True);
    # only enabled when custom bounds are on
    apply_all_bounds = apply_all_bounds_col.button(
        "Apply All",
        use_container_width=True,
        on_click=all_bound_change,
        args=(channel_list, True),
        disabled=not bound_type,
    )
# Section for setting all lower and upper bounds (percent, -100..100)
with all_bounds_col:
    lower_bound_all, upper_bound_all = st.columns([1, 1])
    # Initialize session state keys from the stored [lower, upper] bounds
    st.session_state["all_lower_key"] = (modified_scenario_data["bounds"])[0]
    st.session_state["all_upper_key"] = (modified_scenario_data["bounds"])[1]
    # Input for all lower bounds (callback called with second arg False)
    all_lower_bound = lower_bound_all.number_input(
        "All Lower Bounds",
        min_value=-100,
        max_value=100,
        key="all_lower_key",
        on_change=all_bound_change,
        args=(channel_list, False),
        disabled=not bound_type,
    )
    # Input for all upper bounds (callback called with second arg False)
    all_upper_bound = upper_bound_all.number_input(
        "All Upper Bounds",
        min_value=-100,
        max_value=100,
        key="all_upper_key",
        on_change=all_bound_change,
        args=(channel_list, False),
        disabled=not bound_type,
    )
# Per-channel collectors populated while the "Optimization Inputs" rows are
# rendered: bounds, S-curve parameters, spends, spend proportions, conversion
# ratios and the name/plot placeholders of each channel.
(
    bounds_dict,
    s_curve_params,
    channels_spends,
    channels_proportion,
    channels_conversion_ratio,
    channels_name_plot_placeholder,
) = ({}, {}, {}, {}, {}, {})
# Running sum of the spends typed into the UI, and the flag that later gates
# the Optimize button.
total_channel_spends = 0
optimize_allow = True
# Optimization Inputs UI
with st.expander("Optimization Inputs", expanded=True):
    # Initialize total contributions for actual and optimized spends and
    # metrics. Metric totals start from the summed base "constant", so the
    # per-channel metric shares below are shares of the full metric.
    (
        total_actual_spend_contribution,
        total_actual_metric_contribution,
        total_optimized_spend_contribution,
        total_optimized_metric_contribution,
    ) = (
        0,
        sum(modified_scenario_data["constant"]),
        0,
        sum(modified_scenario_data["constant"]),
    )
    # First pass: accumulate the totals used as denominators for the
    # per-channel percentage contributions (spends scaled by conversion rate).
    for channel in channel_list:
        # Accumulate actual total spends
        total_actual_spend_contribution += (
            modified_scenario_data["channels"][channel]["actual_total_spends"]
            * modified_scenario_data["channels"][channel]["conversion_rate"]
        )
        # Accumulate actual total sales (metrics)
        total_actual_metric_contribution += modified_scenario_data["channels"][
            channel
        ]["actual_total_sales"]
        # Accumulate optimized total spends
        total_optimized_spend_contribution += (
            modified_scenario_data["channels"][channel]["modified_total_spends"]
            * modified_scenario_data["channels"][channel]["conversion_rate"]
        )
        # Accumulate optimized total sales (metrics)
        total_optimized_metric_contribution += modified_scenario_data["channels"][
            channel
        ]["modified_total_sales"]
    # Second pass: one row of widgets per channel
    for channel in channel_list:
        st.divider()
        # Channel key: prefix for all widget keys of this channel
        channel_key = f"{metrics_selected}_{panel_selected}_{channel}"
        # Create columns; a bounds-input column exists only when custom bounds are on
        if st.session_state["bound_type_key"]:
            (
                name_plot_col,
                input_col,
                spends_col,
                metrics_col,
                bounds_input_col,
                bounds_display_col,
                allow_col,
            ) = st.columns([3, 2, 2, 2, 2, 2, 1])
        else:
            (
                name_plot_col,
                input_col,
                spends_col,
                metrics_col,
                bounds_display_col,
                allow_col,
            ) = st.columns([1.5, 1, 1.5, 1.5, 1, 0.5])
            # Placeholder so the name exists; not used while bounds are off
            bounds_input_col = st.empty()
        # Display channel name and ROI/MROI plot
        with name_plot_col:
            # Placeholder for channel name (redrawn later with a background colour)
            channel_name_placeholder = st.empty()
            channel_name_placeholder.markdown(
                display_channel_name_with_background_color(channel),
                unsafe_allow_html=True,
            )
            # Placeholder for ROI and MROI plot, filled after the response curves
            channel_plot_placeholder = st.container()
        # Store placeholder for channel name and ROI/MROI plots
        channels_name_plot_placeholder[channel] = {
            "channel_name_placeholder": channel_name_placeholder,
            "channel_plot_placeholder": channel_plot_placeholder,
        }
        # Channel spends (scaled by conversion rate) and sales
        channel_spends_actual = (
            original_scenario_data["channels"][channel]["actual_total_spends"]
            * original_scenario_data["channels"][channel]["conversion_rate"]
        )
        channel_metrics_actual = original_scenario_data["channels"][channel][
            "actual_total_sales"
        ]
        channel_spends_modified = (
            modified_scenario_data["channels"][channel]["modified_total_spends"]
            * original_scenario_data["channels"][channel]["conversion_rate"]
        )
        channel_metrics_modified = modified_scenario_data["channels"][channel][
            "modified_total_sales"
        ]
        # Channel spends input
        with input_col:
            # Absolute Spends Input, seeded with the current modified spend
            st.session_state[f"{channel_key}_abs_spends_key"] = numerize(
                modified_scenario_data["channels"][channel]["modified_total_spends"]
                * original_scenario_data["channels"][channel]["conversion_rate"]
                / st.session_state["multiplier"]
            )
            absolute_channel_spends = st.text_input(
                "Absolute Spends",
                key=f"{channel_key}_abs_spends_key",
                on_change=absolute_channel_spends_change,
                args=(
                    channel_key,
                    channel_spends_actual,
                    channel,
                    metrics_selected,
                    panel_selected,
                ),
            )
            # Update Percentage Spends Input from the absolute value above.
            # NOTE(review): divides by channel_spends_actual; a channel with zero
            # actual spend would raise here -- confirm this cannot happen.
            st.session_state[f"{channel_key}_per_spends_key"] = int(
                round(
                    (
                        (
                            convert_to_float(
                                st.session_state[f"{channel_key}_abs_spends_key"]
                            )
                            * st.session_state["multiplier"]
                            - float(channel_spends_actual)
                        )
                        / channel_spends_actual
                    )
                    * 100,
                    0,
                )
            )
            # Percentage Spends Input
            percentage_channel_spends = st.number_input(
                "Percentage Spends",
                min_value=-1000,
                max_value=1000,
                key=f"{channel_key}_per_spends_key",
                on_change=percentage_channel_spends_change,
                args=(
                    channel_key,
                    channel_spends_actual,
                    channel,
                    metrics_selected,
                    panel_selected,
                ),
            )
        # Store channel spends (raw, i.e. before the conversion rate), the
        # conversion ratio, and each element of the channel's `spends` series
        # as a share of its sum.
        channels_spends[channel] = original_scenario_data["channels"][channel][
            "actual_total_spends"
        ] * (1 + percentage_channel_spends / 100)
        channels_conversion_ratio[channel] = original_scenario_data["channels"][
            channel
        ]["conversion_rate"]
        channels_proportion[channel] = original_scenario_data["channels"][
            channel
        ]["spends"] / sum(original_scenario_data["channels"][channel]["spends"])
        # Calculate the percent contribution of actual spends for the channel
        channel_actual_spend_contribution = round(
            (
                modified_scenario_data["channels"][channel][
                    "actual_total_spends"
                ]
                * channels_conversion_ratio[channel]
                / total_actual_spend_contribution
            )
            * 100,
            1,
        )
        # Calculate the percent contribution of actual metrics (sales) for the channel
        channel_actual_metric_contribution = round(
            (
                modified_scenario_data["channels"][channel][
                    "actual_total_sales"
                ]
                / total_actual_metric_contribution
            )
            * 100,
            1,
        )
        # Calculate the percent contribution of optimized spends for the channel
        channel_optimized_spend_contribution = round(
            (
                modified_scenario_data["channels"][channel][
                    "modified_total_spends"
                ]
                * channels_conversion_ratio[channel]
                / total_optimized_spend_contribution
            )
            * 100,
            1,
        )
        # Calculate the percent contribution of optimized metrics (sales) for the channel
        channel_optimized_metric_contribution = round(
            (
                modified_scenario_data["channels"][channel][
                    "modified_total_sales"
                ]
                / total_optimized_metric_contribution
            )
            * 100,
            1,
        )
        # Channel metrics display
        with metrics_col:
            # Absolute Metrics
            st.metric(
                f"Actual {name_formating(metrics_selected)}",
                value=str(
                    numerize(
                        channel_metrics_actual / st.session_state["multiplier"]
                    )
                )
                + f"({channel_actual_metric_contribution}%)",
            )
            # Optimized Metrics
            optimized_metric = (
                channel_metrics_modified / st.session_state["multiplier"]
            )
            actual_metric = channel_metrics_actual / st.session_state["multiplier"]
            delta_value = (
                channel_metrics_modified - channel_metrics_actual
            ) / st.session_state["multiplier"]
            # Show a zero delta when the change is below 0.1% of the actual metric
            if (
                abs(delta_value) < 0.001 * actual_metric
            ):  # 0.1% of the actual metric
                delta_display = "0"
            else:
                delta_display = numerize(delta_value)
            st.metric(
                f"Optimized {name_formating(metrics_selected)}",
                value=str(numerize(optimized_metric))
                + f"({channel_optimized_metric_contribution}%)",
                delta=delta_display,
            )
        # Channel spends display
        with spends_col:
            # Absolute Spends
            st.metric(
                "Actual Spend",
                value=str(
                    numerize(channel_spends_actual / st.session_state["multiplier"])
                )
                + f"({channel_actual_spend_contribution}%)",
            )
            # Optimized Spends
            optimized_spends = (
                channel_spends_modified / st.session_state["multiplier"]
            )
            actual_spends = channel_spends_actual / st.session_state["multiplier"]
            delta_spends_value = (
                channel_spends_modified - channel_spends_actual
            ) / st.session_state["multiplier"]
            # Show a zero delta when the change is below 0.1% of the actual spend
            if (
                abs(delta_spends_value) < 0.001 * actual_spends
            ):  # 0.1% of the actual spend
                delta_spends_display = "0"
            else:
                delta_spends_display = numerize(delta_spends_value)
            st.metric(
                "Optimized Spend",
                value=str(numerize(optimized_spends))
                + f"({channel_optimized_spend_contribution}%)",
                delta=delta_spends_display,
            )
        # Channel allows optimize
        with allow_col:
            # Allow Optimize (Freeze), seeded from the stored "freeze" flag
            st.write("#")  # Padding
            st.session_state[f"{channel_key}_allow_optimize_key"] = (
                modified_scenario_data["channels"][channel]["freeze"]
            )
            freeze = st.checkbox(
                "Freeze",
                key=f"{channel_key}_allow_optimize_key",
                on_change=freeze_change,
                args=(
                    metrics_selected,
                    panel_selected,
                    channel_key,
                    channel,
                    channel_list,
                ),
            )
        # If channel is frozen, set bounds to keep the spend unchanged.
        # NOTE(review): this assignment is always overwritten (the bounds branch
        # takes the number_input values, the else branch reassigns both), so
        # freezing does not change bounds_dict here -- confirm freezing is
        # enforced elsewhere (e.g. freeze_change or the optimizer).
        if freeze:
            lower_bound, upper_bound = 0, 0  # Freeze the spend at current level
        # Channel bounds input (only when custom bounds are enabled)
        if st.session_state["bound_type_key"]:
            with bounds_input_col:
                # Channel upper bound; disabled while the channel is frozen
                st.session_state[f"{channel_key}_upper_key"] = (
                    modified_scenario_data["channels"][channel]["bounds"]
                )[1]
                upper_bound = st.number_input(
                    "Upper bound (%)",
                    min_value=-100,
                    max_value=100,
                    key=f"{channel_key}_upper_key",
                    disabled=st.session_state[f"{channel_key}_allow_optimize_key"],
                    on_change=bound_change,
                    args=(
                        metrics_selected,
                        panel_selected,
                        channel_key,
                        channel,
                    ),
                )
                # Channel lower bound; disabled while the channel is frozen
                st.session_state[f"{channel_key}_lower_key"] = (
                    modified_scenario_data["channels"][channel]["bounds"]
                )[0]
                lower_bound = st.number_input(
                    "Lower bound (%)",
                    min_value=-100,
                    max_value=100,
                    key=f"{channel_key}_lower_key",
                    disabled=st.session_state[f"{channel_key}_allow_optimize_key"],
                    on_change=bound_change,
                    args=(
                        metrics_selected,
                        panel_selected,
                        channel_key,
                        channel,
                    ),
                )
            # Check if lower bound is greater than upper bound
            if lower_bound > upper_bound:
                lower_bound = -10  # Default lower bound
                upper_bound = 10  # Default upper bound
            # Store bounds
            bounds_dict[channel] = [lower_bound, upper_bound]
        else:
            # lower_bound / upper_bound here only drive the bounds display below.
            # NOTE(review): the optimizer receives the stored bounds instead, so
            # the displayed 0 / +/-10 can differ from what is actually used.
            if freeze:
                lower_bound, upper_bound = 0, 0  # Freeze the spend at current level
            else:
                lower_bound = -10  # Default lower bound
                upper_bound = 10  # Default upper bound
            # Store bounds
            bounds_dict[channel] = modified_scenario_data["channels"][channel][
                "bounds"
            ]
        # Display the bounds for each channel's spend in the bounds_display_col
        with bounds_display_col:
            # Current modified spend of the channel (times its conversion rate),
            # taken from the modified scenario data
            actual_spends = (
                modified_scenario_data["channels"][channel]["modified_total_spends"]
                * modified_scenario_data["channels"][channel]["conversion_rate"]
            )
            # Calculate the limit for spends
            upper_limit_spends = actual_spends * (1 + upper_bound / 100)
            lower_limit_spends = actual_spends * (1 + lower_bound / 100)
            # Display the upper and lower limit spends
            st.metric(
                "Upper Bound",
                numerize(upper_limit_spends / st.session_state["multiplier"]),
            )
            st.metric(
                "Lower Bound",
                numerize(lower_limit_spends / st.session_state["multiplier"]),
            )
        # Store S-curve parameters
        s_curve_params[channel] = get_s_curve_params(
            metrics_selected,
            panel_selected,
            channel,
            original_rcs_data,
            modified_rcs_data,
        )
        # Total channel spends, as entered in the absolute inputs
        total_channel_spends += (
            convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"])
            * st.session_state["multiplier"]
        )
# Check if total channel spends are within the allowed range (±50% of the
# original total spends). This only records a warning; it does not block
# optimization.
if (
    total_channel_spends > 1.5 * original_scenario_data["actual_total_spends"]
    or total_channel_spends
    < 0.5 * original_scenario_data["actual_total_spends"]
):
    # Store the message details in session state
    st.session_state.message_display = {
        "type": "warning",
        "message": "Keep total spending within ±50% of the original value.",
        "icon": "⚠️",
    }
# Range of totals reachable within the per-channel bounds
if optimization_goal == "Spend":
    # Get maximum achievable spends: each channel's entered spend (times its
    # conversion ratio) scaled by its lower / upper bound percentage.
    lower_achievable_target, upper_achievable_target = 0, 0
    for channel in channel_list:
        # Note: reuses the name channel_spends_actual, but here it is the
        # spend entered in the UI, not the channel's actual spend.
        channel_spends_actual = (
            channels_spends[channel] * channels_conversion_ratio[channel]
        )
        lower_achievable_target += channel_spends_actual * (
            1 + bounds_dict[channel][0] / 100
        )
        upper_achievable_target += channel_spends_actual * (
            1 + bounds_dict[channel][1] / 100
        )
else:
    # Get maximum achievable target metric
    lower_achievable_target, upper_achievable_target = max_target_achievable(
        channels_spends,
        s_curve_params,
        channels_proportion,
        modified_scenario_data,
        bounds_dict,
    )
# Total target of selected metric: the current modified total of the fixed quantity
if optimization_goal == "Spend":
    total_absolute_target = modified_scenario_data["modified_total_spends"]
else:
    total_absolute_target = modified_scenario_data["modified_total_sales"]
# Check if the target is achievable within the specified bounds; the result
# decides whether the Optimize button is enabled.
if optimize_allow:
    optimize_allow = check_target_achievability(
        optimize_allow,
        name_formating(optimization_goal),
        lower_achievable_target,
        upper_achievable_target,
        total_absolute_target,
    )
# Perform the optimization when Optimize is clicked (disabled when the target
# was judged unreachable).
if optimize_button_col.button(
    "Optimize",
    use_container_width=True,
    disabled=not optimize_allow,
    key="run_optimizer",
):
    with message_display_col:
        with spinner_placeholder, st.spinner("Optimizing ..."):
            # Call the optimizer function to get optimized spends
            optimized_spends, optimization_success = optimizer(
                optimization_goal,
                s_curve_params,
                channels_spends,
                channels_proportion,
                channels_conversion_ratio,
                total_absolute_target,
                bounds_dict,
                modified_scenario_data,
            )
            # Initialize dictionaries to store input and output channel spends
            input_channels_spends, output_channels_spends = {}, {}
            for channel in channel_list:
                # Calculate input channel spends by converting spends using conversion ratio
                input_channels_spends[channel] = (
                    channels_spends[channel] * channels_conversion_ratio[channel]
                )
                # Calculate output channel spends by converting optimized spends using conversion ratio
                output_channels_spends[channel] = (
                    optimized_spends[channel] * channels_conversion_ratio[channel]
                )
            # Calculate total actual and modified spends (converted units)
            actual_total_spends = sum(list(input_channels_spends.values()))
            modified_total_spends = sum(list(output_channels_spends.values()))
            # Pre-optimization total metric, i.e. the current modified total
            actual_total_metrics = modified_scenario_data["modified_total_sales"]
            modified_total_metrics = 0  # Initialize modified total metrics
            modified_channels_metrics = {}
            # Calculate modified metrics for each channel
            for channel in optimized_spends.keys():
                channel_s_curve_params = s_curve_params[channel]
                spend_proportion = (
                    optimized_spends[channel] * channels_proportion[channel]
                )
                # Calculate the metrics using the S-curve function
                modified_channels_metrics[channel] = sum(
                    s_curve(
                        spend_proportion,
                        channel_s_curve_params["power"],
                        channel_s_curve_params["K"],
                        channel_s_curve_params["b"],
                        channel_s_curve_params["a"],
                        channel_s_curve_params["x0"],
                    )
                ) + sum(
                    modified_scenario_data["channels"][channel]["correction"]
                )  # correction for s-curve
                modified_total_metrics += modified_channels_metrics[
                    channel
                ]  # Add channel metrics to total metrics
            # Add the base constant (per-channel corrections were already added above)
            modified_total_metrics += sum(modified_scenario_data["constant"])
            # Retrieve the original total spends from modified scenario data
            original_total_spends = modified_scenario_data["actual_total_spends"]
            # Check the success of the optimization process
            success, message, icon = check_optimization_success(
                channel_list,
                input_channels_spends,
                output_channels_spends,
                bounds_dict,
                optimization_goal,
                modified_total_metrics,
                actual_total_metrics,
                modified_total_spends,
                actual_total_spends,
                original_total_spends,
                optimization_success,
            )
            # Store the message details in session state
            st.session_state.message_display = {
                "type": "success" if success else "error",
                "message": message,
                "icon": icon,
            }
            # Update data only if the optimization is successful
            if success:
                # Per-channel spends are stored unconverted (raw optimizer
                # output) while the total below is in converted units.
                for channel in channel_list:
                    modified_scenario_data["channels"][channel][
                        "modified_total_spends"
                    ] = optimized_spends[channel]
                    # Update the modified metrics for each channel in the scenario data
                    modified_scenario_data["channels"][channel][
                        "modified_total_sales"
                    ] = modified_channels_metrics[channel]
                # Update the total modified spends in the scenario data
                modified_scenario_data["modified_total_spends"] = (
                    modified_total_spends
                )
                # Update the total modified metrics in the scenario data
                modified_scenario_data["modified_total_sales"] = (
                    modified_total_metrics
                )
                # Load modified scenario data
                data = st.session_state["project_dct"]["scenario_planner"][
                    "modified_metadata_file"
                ]
                # Update the specific section with the modified scenario data
                data[metrics_selected][panel_selected] = modified_scenario_data
                # Update modified scenario metadata
                st.session_state["project_dct"]["scenario_planner"][
                    "modified_metadata_file"
                ] = data
                # NOTE(review): source indentation was lost; the reset/rerun are
                # reconstructed inside `if success:` -- confirm placement.
                # Reset optimizer button
                del st.session_state["run_optimizer"]
                # Rerun to update values
                st.rerun()
########################################## Response Curves ##########################################
# Generate plots (built from the original scenario data)
figures, channel_roi_mroi, region_start_end = generate_response_curve_plots(
    channel_list,
    s_curve_params,
    channels_proportion,
    original_scenario_data,
    st.session_state["multiplier"],
)
# Display Response Curves in a grid of 4 per row
st.subheader(f"Response Curves (X: Spends Vs Y: {metrics_selected_formatted})")
with st.expander("Response Curves", expanded=True):
    cols = st.columns(4)  # Create 4 columns for the first row
    for i, fig in enumerate(figures):
        col = cols[i % 4]  # Rotate through the columns
        with col:
            # Get channel parameters.
            # Presumably figures are returned in channel_list order -- confirm.
            channel = channel_list[i]
            modified_total_spends = modified_scenario_data["channels"][channel][
                "modified_total_spends"
            ]
            conversion_rate = modified_scenario_data["channels"][channel][
                "conversion_rate"
            ]
            channel_correction = sum(
                modified_scenario_data["channels"][channel]["correction"]
            )
            # Updated figure with modified metrics point
            roi_optimized, mroi_optimized, fig_updated = modified_metrics_point(
                fig,
                modified_total_spends,
                s_curve_params[channel],
                channels_proportion[channel],
                conversion_rate,
                channel_correction,
            )
            # Store data of each channel ROI and MROI
            channel_roi_mroi[channel]["optimized_roi"] = roi_optimized
            channel_roi_mroi[channel]["optimized_mroi"] = mroi_optimized
            st.plotly_chart(fig_updated, use_container_width=True)
        # Start a new row after every 4 plots (only if more plots remain)
        if (i + 1) % 4 == 0 and i + 1 < len(figures):
            cols = st.columns(4)  # Create new row with 4 columns
# Generate the per-channel ROI/MROI plots
channel_roi_mroi_plot = roi_mori_plot(channel_roi_mroi)
# Fill the per-channel placeholders created in the Optimization Inputs rows:
# ROI/MROI plots and the channel name with a background colour
for channel in channel_list:
    with channels_name_plot_placeholder[channel]["channel_plot_placeholder"]:
        # Create subplots with 2 columns for ROI and MROI
        roi_plot_col, mroi_plot_col = st.columns(2)
        # Display ROI and MROI plots
        roi_plot_col.plotly_chart(channel_roi_mroi_plot[channel]["fig_roi"])
        mroi_plot_col.plotly_chart(channel_roi_mroi_plot[channel]["fig_mroi"])
    # Placeholder for the channel name
    channel_name_placeholder = channels_name_plot_placeholder[channel][
        "channel_name_placeholder"
    ]
    # Retrieve modified total spends and conversion rate for the channel
    modified_total_spends = modified_scenario_data["channels"][channel][
        "modified_total_spends"
    ]
    conversion_rate = modified_scenario_data["channels"][channel]["conversion_rate"]
    # Calculate the converted spend value for the channel
    channel_spends_value = modified_total_spends * conversion_rate
    # Colour for the channel based on where its spend falls within the
    # channel's entry in region_start_end
    channel_rgba_value = calculate_rgba(
        channel_spends_value, region_start_end[channel]
    )
    # Display the channel name with the calculated background color
    channel_name_placeholder.markdown(
        display_channel_name_with_background_color(channel, channel_rgba_value),
        unsafe_allow_html=True,
    )
# Scenario name input; saving stays disabled until a non-empty name exists
st.text_input("Scenario Name", key="scenario_name")
save_scenario_button_disabled = not st.session_state["scenario_name"]
# Save Scenario button rendered into its own placeholder.
# NOTE(review): on_click callbacks run before the next script run, so this
# spinner only wraps the button's rendering, not save_scenario itself.
save_button_placeholder = st.empty()
with st.spinner("Saving ..."):
    save_button_placeholder.button(
        "Save Scenario",
        on_click=save_scenario,
        args=(
            modified_scenario_data,
            metrics_selected,
            panel_selected,
            optimization_goal,
            channel_roi_mroi,
            st.session_state["timeframe_selected_selectbox_sp"],
            st.session_state["multiplier"],
        ),
        disabled=save_scenario_button_disabled,
    )
########################################## Display Message ##########################################
# Render the pending status message in the control row's message area.
# Presumably display_message reads st.session_state.message_display, which is
# set by the save/optimize/validation steps -- confirm.
with message_display_col:
    display_message()
except Exception as e: | |
# Capture the error details | |
exc_type, exc_value, exc_traceback = sys.exc_info() | |
error_message = "".join( | |
traceback.format_exception(exc_type, exc_value, exc_traceback) | |
) | |
# Log message | |
log_message("error", f"An error occurred: {error_message}.", "Scenario Planner") | |
# Display a warning message | |
st.warning( | |
"Oops! Something went wrong. Please try refreshing the tool or creating a new project.", | |
icon="⚠️", | |
) | |