MediaMixOptimization / pages /9_Scenario_Planner.py
samkeet's picture
Upload 40 files
00b00eb verified
# Importing necessary libraries
import streamlit as st
# Configure the browser tab and layout. This call sits directly after the
# streamlit import, ahead of the remaining imports — presumably because
# Streamlit expects set_page_config to be the first st.* call of a run.
st.set_page_config(
    page_title="Scenario Planner",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Inject page CSS: hide the +/- step buttons of number inputs and give
# BaseWeb-backed widgets 4px rounded corners.
st.markdown(
    """
<style>
button.step-up {display: none;}
button.step-down {display: none;}
div[data-baseweb] {border-radius: 4px;}
</style>""",
    unsafe_allow_html=True,
)
import re
import sys
import copy
import pickle
import traceback
import numpy as np
import pandas as pd
from scenario import numerize
import plotly.graph_objects as go
from post_gres_cred import db_cred
from scipy.optimize import minimize
from log_application import log_message
from utilities import project_selection, update_db, set_header, load_local_css
from utilities import (
get_panels_names,
get_metrics_names,
name_formating,
load_rcs_metadata_files,
load_scenario_metadata_files,
generate_rcs_data,
generate_scenario_data,
)
from constants import (
xtol_tolerance_per,
mroi_threshold,
word_length_limit_lower,
word_length_limit_upper,
)
# Database schema name taken from the Postgres credentials module.
schema = db_cred["schema"]
# Load the local stylesheet and render the shared page header (helpers from utilities).
load_local_css("styles.css")
set_header()
# Initialize project name session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None
# No project loaded yet: hand over to project_selection (presumably the
# project picker UI) and halt this script run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()
# Show a welcome line and the current project name for logged-in users.
if "username" in st.session_state and st.session_state["username"] is not None:
    cols1 = st.columns([2, 1])
    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")
# Default ROI threshold; read by generate_response_curve_plots when shading
# the curve regions.
if "roi_threshold" not in st.session_state:
    st.session_state.roi_threshold = 1
# Holder for the single user-facing message; callbacks overwrite it with
# warnings, reset_scenario clears it.
if "message_display" not in st.session_state:
    st.session_state.message_display = {"type": "success", "message": None, "icon": ""}
# Function to reset modified_scenario_data
def reset_scenario(metrics_selected=None, panel_selected=None):
    """Restore one metric/panel slice of the modified scenario to its original.

    Args:
        metrics_selected: Response-metric key. When None, the value of the
            ``response_metrics_selectbox_sp`` widget is used.
        panel_selected: Panel key. When None, the value of the
            ``panel_selected_selectbox_sp`` widget is used.

    Side effects:
        Clears ``st.session_state.message_display`` and replaces
        ``modified_metadata_file[metric][panel]`` with a deep copy of the
        matching ``original_metadata_file`` entry.
    """
    st.session_state.message_display = {"type": "success", "message": None, "icon": ""}

    metric = (
        st.session_state["response_metrics_selectbox_sp"]
        if metrics_selected is None
        else metrics_selected
    )
    panel = (
        st.session_state["panel_selected_selectbox_sp"]
        if panel_selected is None
        else panel_selected
    )

    planner = st.session_state["project_dct"]["scenario_planner"]
    pristine_slice = planner["original_metadata_file"][metric][panel]
    modified_file = planner["modified_metadata_file"]
    # Deep copy so later edits to the modified scenario never leak into the original.
    modified_file[metric][panel] = copy.deepcopy(pristine_slice)
    planner["modified_metadata_file"] = modified_file
# Function to build s curve
def s_curve(x, power, K, b, a, x0):
    """Evaluate the logistic response curve.

    The input ``x`` is first rescaled by ``10**power``; the result is
    ``K / (1 + b * exp(-a * (x_scaled - x0)))``. ``np.exp`` makes this work
    elementwise on NumPy arrays as well as on scalars.
    """
    x_scaled = x / 10**power
    exponent = -a * (x_scaled - x0)
    return K / (1 + b * np.exp(exponent))
# Function to retrieve S-curve parameters for a given metric, panel, and channel
def get_s_curve_params(
    metrics_selected,
    panel_selected,
    channel_selected,
    original_rcs_data,
    modified_rcs_data,
):
    """Return the S-curve parameters for one channel, with the original ``power``.

    The parameter dict is taken from ``modified_rcs_data`` and is recorded as
    the channel's ``response_curve_params`` in the modified scenario metadata.
    Its ``power`` entry is then set from ``original_rcs_data``.

    NOTE(review): no copy is made. The returned dict is the same object held in
    ``modified_rcs_data`` and in the scenario metadata, so setting ``power``
    mutates both. Confirm that this sharing is intended.
    """
    original_power = original_rcs_data[metrics_selected][panel_selected][
        channel_selected
    ]["power"]
    params = modified_rcs_data[metrics_selected][panel_selected][channel_selected]

    planner = st.session_state["project_dct"]["scenario_planner"]
    scenario_file = planner["modified_metadata_file"]
    channel_entry = scenario_file[metrics_selected][panel_selected]["channels"][
        channel_selected
    ]
    channel_entry["response_curve_params"] = params
    planner["modified_metadata_file"] = scenario_file

    params["power"] = original_power
    return params
# Function to calculate total contribution
def get_total_contribution(
    spends, channels, s_curve_params, channels_proportion, modified_scenario_data
):
    """Predict the total contribution for a vector of channel spends.

    Args:
        spends: Per-channel spends, positionally aligned with ``channels``.
        channels: Channel names.
        s_curve_params: Mapping channel -> S-curve parameters (``power``, ``K``,
            ``b``, ``a``, ``x0``).
        channels_proportion: Mapping channel -> factor(s) multiplied into the
            spend before evaluating the curve. The curve output is passed to
            builtin ``sum``, so this is presumably an array, e.g. per period.
            TODO confirm.
        modified_scenario_data: Scenario slice supplying each channel's
            ``correction`` values and the ``constant`` term.

    Returns:
        For every channel, the summed curve output plus its summed corrections,
        all added together, plus the summed ``constant`` term.
    """
    running_total = 0
    for idx, name in enumerate(channels):
        curve = s_curve_params[name]
        scaled_spend = spends[idx] * channels_proportion[name]
        curve_total = sum(
            s_curve(
                scaled_spend,
                curve["power"],
                curve["K"],
                curve["b"],
                curve["a"],
                curve["x0"],
            )
        )
        # Per-channel correction applied on top of the fitted S-curve.
        correction_total = sum(modified_scenario_data["channels"][name]["correction"])
        running_total += curve_total + correction_total
    return running_total + sum(modified_scenario_data["constant"])
# Function to calculate total spends
def get_total_spends(spends, channels_conversion_ratio):
    """Return the conversion-weighted total of ``spends``.

    ``spends`` is multiplied elementwise by the conversion ratios, taken in the
    mapping's iteration order. The two must therefore be positionally aligned.
    """
    ratios = np.array([*channels_conversion_ratio.values()])
    weighted_spends = spends * ratios
    return np.sum(weighted_spends)
# Function to optimize spends for all channels given bounds and a total spend or contribution target
def optimizer(
    optimization_goal,
    s_curve_params,
    channels_spends,
    channels_proportion,
    channels_conversion_ratio,
    total_target,
    bounds_dict,
    modified_scenario_data,
):
    """Optimize per-channel spends with SciPy's ``trust-constr`` method.

    Modes:
      * ``optimization_goal == "Spend"``: maximize total contribution subject
        to the conversion-weighted total spend equalling ``total_target``.
      * any other goal: minimize conversion-weighted total spend subject to
        total contribution equalling ``total_target``.

    Args:
        optimization_goal: ``"Spend"`` or another goal label, which is treated
            as a contribution target.
        s_curve_params: Mapping channel -> S-curve parameters.
        channels_spends: Mapping channel -> current spend. Its key order fixes
            the order of the decision vector.
        channels_proportion: Mapping channel -> proportion used by
            ``get_total_contribution``.
        channels_conversion_ratio: Mapping channel -> conversion ratio used to
            weight spends when totalling them.
        total_target: Right-hand side of the equality constraint.
        bounds_dict: Mapping channel -> ``(lower_pct, upper_pct)``, the allowed
            percentage change from the current spend.
        modified_scenario_data: Scenario slice passed to
            ``get_total_contribution``.

    Returns:
        ``(optimized_spends, success)``. ``optimized_spends`` maps each channel
        to its optimized spend clipped at 0. ``success`` is SciPy's
        ``OptimizeResult.success`` flag.
    """
    channels = list(channels_spends.keys())
    actual_spends = np.array(list(channels_spends.values()))

    # get_total_spends pairs spends with ratios purely by iteration order, so
    # rebuild the ratio mapping in ``channels`` order. Each spend is then
    # weighted by its own channel's ratio. If the mapping does not cover every
    # channel, keep the caller's mapping (previous behavior).
    if all(channel in channels_conversion_ratio for channel in channels):
        ordered_conversion_ratio = {
            channel: channels_conversion_ratio[channel] for channel in channels
        }
    else:
        ordered_conversion_ratio = channels_conversion_ratio

    def total_contribution(spends):
        return get_total_contribution(
            spends,
            channels,
            s_curve_params,
            channels_proportion,
            modified_scenario_data,
        )

    def objective_fun(spends):
        if optimization_goal == "Spend":
            # Maximize contribution by minimizing its negative.
            return -total_contribution(spends)
        # Minimize total spends.
        return get_total_spends(spends, ordered_conversion_ratio)

    def constraint_fun(spends):
        if optimization_goal == "Spend":
            # Total spend must hit the spend target.
            return get_total_spends(spends, ordered_conversion_ratio)
        # Total contribution must hit the contribution target.
        return total_contribution(spends)

    # Equality constraint: constraint_fun(spends) == total_target.
    constraints = {
        "type": "eq",
        "fun": lambda spends: constraint_fun(spends) - total_target,
    }

    # Per-channel bounds as percentage changes around the current spend.
    bounds = [
        (
            spend * (1 + bounds_dict[channel][0] / 100),
            spend * (1 + bounds_dict[channel][1] / 100),
        )
        for channel, spend in zip(channels, actual_spends)
    ]

    initial_guess = np.array(actual_spends)

    # Step tolerance: xtol_tolerance_per percent of the smallest channel spend,
    # never below 10.
    xtol = max(10, (xtol_tolerance_per / 100) * np.min(actual_spends))

    result = minimize(
        objective_fun,
        initial_guess,
        method="trust-constr",
        constraints=constraints,
        bounds=bounds,
        options={
            "disp": True,  # Print optimizer progress
            "xtol": xtol,  # Dynamic step size tolerance
            "maxiter": 100_000,  # trust-constr documents maxiter as an int
        },
    )

    # Map results back to channel names, clipping tiny negative values at 0.
    optimized_spends = {
        channel: max(0, value) for channel, value in zip(channels, result.x)
    }
    return optimized_spends, result.success
# Function to calculate achievable targets at lower and upper spend bounds
@st.cache_data(show_spinner=False)
def max_target_achievable(
    channels_spends,
    s_curve_params,
    channels_proportion,
    modified_scenario_data,
    bounds_dict,
):
    """Return the contribution range reachable within the per-channel bounds.

    Each channel's spend is moved to its lower and then to its upper percentage
    bound from ``bounds_dict``. Total contribution is evaluated at both
    extremes. Results are memoized by ``st.cache_data``.

    Returns:
        ``(lower, upper)`` with a 0.1% safety margin on each side. The lower
        figure is raised by 0.1% and floored at 0. The upper figure is reduced
        by 0.1%.
    """
    channel_names = list(channels_spends.keys())
    base_spends = np.array(list(channels_spends.values()))

    lower_spends = [
        spend * (1 + bounds_dict[name][0] / 100)
        for name, spend in zip(channel_names, base_spends)
    ]
    upper_spends = [
        spend * (1 + bounds_dict[name][1] / 100)
        for name, spend in zip(channel_names, base_spends)
    ]

    floor_target = get_total_contribution(
        lower_spends,
        channel_names,
        s_curve_params,
        channels_proportion,
        modified_scenario_data,
    )
    ceiling_target = get_total_contribution(
        upper_spends,
        channel_names,
        s_curve_params,
        channels_proportion,
        modified_scenario_data,
    )
    return max(0, 1.001 * floor_target), 0.999 * ceiling_target
# Function to check if number is in valid format
def _parse_non_negative_number(number_str):
    """Return the numeric part of ``number_str`` as a float, or None if invalid.

    Accepts anything ``float`` parses, optionally followed by one K/M/B/T
    suffix (case-insensitive). The suffix scale is NOT applied here;
    ``convert_to_float`` does that. Empty, negative, NaN and infinite values
    are rejected.
    """
    # ``not number_str`` covers None and "". Before, "" crashed with IndexError.
    if not number_str or number_str[0] == "-":
        return None
    last_char = number_str[-1]
    if last_char.isdigit():
        num_part = number_str
    elif last_char.upper() in {"K", "M", "B", "T"}:
        num_part = number_str[:-1]
    else:
        return None
    try:
        number = float(num_part)
    except ValueError:
        return None
    # float() also accepts e.g. " -5", "nan" and "1e999"; none is a usable spend.
    if not np.isfinite(number) or number < 0:
        return None
    return number


def is_valid_number_format(number_str):
    """Check that ``number_str`` is a non-negative number with an optional suffix.

    Valid examples: ``"1500"``, ``"2.5"``, ``"1.5k"``, ``"3M"``. The suffixes
    K, M, B and T are accepted in either case.

    Returns:
        True if the string is valid. Otherwise False, after storing a warning
        in ``st.session_state.message_display`` for the page to display.
    """
    if _parse_non_negative_number(number_str) is not None:
        return True
    st.session_state.message_display = {
        "type": "warning",
        "message": "Invalid input: Please enter a valid number.",
        "icon": "⚠️",
    }
    return False
# Function to convert a string with number suffixes (K, M, B, T) to a float
def convert_to_float(number_str):
    """Convert a numeric string with an optional K/M/B/T suffix to a float.

    Examples: ``"1500"`` -> 1500.0, ``"2.5k"`` -> 2500.0, ``"3M"`` -> 3e6.
    The suffix is case-insensitive. Callers are expected to validate with
    ``is_valid_number_format`` first. A malformed number raises ``ValueError``,
    an unknown suffix raises ``KeyError`` and an empty string raises
    ``IndexError``.
    """
    last_char = number_str[-1]
    if last_char.isdigit():
        return float(number_str)
    mantissa = float(number_str[:-1])
    scale = {"K": 1e3, "M": 1e6, "B": 1e9, "T": 1e12}[last_char.upper()]
    return mantissa * scale
# Function to update absolute_channel_spends change
def absolute_channel_spends_change(
    channel_key, channel_spends_actual, channel, metrics_selected, panel_selected
):
    """Widget callback: apply an absolute spend typed for one channel.

    Reads ``<channel_key>_abs_spends_key`` and scales it by
    ``st.session_state["multiplier"]``. It then sums every channel's typed
    absolute spend for the panel, scaled the same way. The edit is applied
    only if that sum lies strictly between 50% and 150% of the panel's
    ``actual_total_spends``. In that case the channel's
    ``modified_total_spends`` becomes the new spend divided by its
    ``conversion_rate``, and the panel's ``modified_total_spends`` becomes the
    sum. Otherwise a warning is stored in ``message_display``.

    ``channel_spends_actual`` is accepted to match the other spend callbacks
    but is not used here.
    """
    typed_value = st.session_state[f"{channel_key}_abs_spends_key"]
    if not is_valid_number_format(typed_value):
        return
    new_absolute_spends = convert_to_float(typed_value) * st.session_state["multiplier"]

    scenario_file = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    panel_data = scenario_file[metrics_selected][panel_selected]

    # Sum of all channels' currently typed spends for this panel.
    total_channel_spends = 0
    for name in panel_data["channels"]:
        widget_key = f"{metrics_selected}_{panel_selected}_{name}_abs_spends_key"
        total_channel_spends += (
            convert_to_float(st.session_state[widget_key])
            * st.session_state["multiplier"]
        )

    reference_total = panel_data["actual_total_spends"]
    within_range = 0.5 * reference_total < total_channel_spends < 1.5 * reference_total
    if not within_range:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Keep total spending within ±50% of the original value.",
            "icon": "⚠️",
        }
        return

    channel_entry = panel_data["channels"][channel]
    channel_entry["modified_total_spends"] = new_absolute_spends / float(
        channel_entry["conversion_rate"]
    )
    panel_data["modified_total_spends"] = total_channel_spends
    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = scenario_file
# Function to update percentage_channel_spends change
def percentage_channel_spends_change(
    channel_key, channel_spends_actual, channel, metrics_selected, panel_selected
):
    """Widget callback: apply a percentage spend change typed for one channel.

    The channel's new spend is ``channel_spends_actual`` scaled by the widget's
    percentage rounded to a whole number. Before storing, it is divided by the
    channel's ``conversion_rate``, so ``channel_spends_actual`` is presumably
    already in converted units. TODO confirm.

    The panel total is rebuilt from every channel's ``actual_total_spends *
    conversion_rate`` scaled by its (unrounded) percentage widget. Changes are
    applied only if that total lies strictly between 50% and 150% of the
    panel's ``actual_total_spends``. Otherwise a warning is stored in
    ``message_display``.

    NOTE(review): the edited channel's stored spend uses the rounded
    percentage, but its share of the panel total uses the unrounded value.
    """
    percent_change = round(st.session_state[f"{channel_key}_per_spends_key"], 0)
    new_absolute_spends = channel_spends_actual * (1 + percent_change / 100)

    scenario_file = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    panel_data = scenario_file[metrics_selected][panel_selected]

    total_channel_spends = 0
    for name, channel_info in panel_data["channels"].items():
        widget_key = f"{metrics_selected}_{panel_selected}_{name}_per_spends_key"
        total_channel_spends += (
            channel_info["actual_total_spends"]
            * channel_info["conversion_rate"]
            * (1 + st.session_state[widget_key] / 100)
        )

    reference_total = panel_data["actual_total_spends"]
    if not (0.5 * reference_total < total_channel_spends < 1.5 * reference_total):
        st.session_state.message_display = {
            "type": "warning",
            "message": "Keep total spending within ±50% of the original value.",
            "icon": "⚠️",
        }
        return

    channel_entry = panel_data["channels"][channel]
    channel_entry["modified_total_spends"] = float(new_absolute_spends) / float(
        channel_entry["conversion_rate"]
    )
    panel_data["modified_total_spends"] = total_channel_spends
    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = scenario_file
# # Function to update total input change
# def total_input_change(per_change):
# # Load modified scenario data
# data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
# # Get the list of all channels in the specified panel and metric
# channel_list = list(data[metrics_selected][panel_selected]["channels"].keys())
# # Iterate over each channel to update their modified spends
# for channel in channel_list:
# # Retrieve the actual spends for the channel
# channel_actual_spends = data[metrics_selected][panel_selected]["channels"][
# channel
# ]["actual_total_spends"]
# # Calculate the modified spends for the channel based on the percent change
# modified_channel_metrics = channel_actual_spends * ((100 + per_change) / 100)
# # Update the channel's modified total spends in the data
# data[metrics_selected][panel_selected]["channels"][channel][
# "modified_total_spends"
# ] = modified_channel_metrics
# # Update modified scenario metadata
# st.session_state["project_dct"]["scenario_planner"][
# "modified_metadata_file"
# ] = data
# Function to update total input change
def total_input_change(per_change, metrics_selected, panel_selected):
    """Spread a total-spend percentage change across the panel's channels.

    A channel whose ``<metric>_<panel>_<channel>_allow_optimize_key`` session
    value is truthy is treated as frozen and keeps its current
    ``modified_total_spends``. Each frozen channel would have received its
    original spend scaled by ``per_change``. The gap between that and its
    current spend, weighted by conversion rate, is handed to the unfrozen
    channels in proportion to their conversion-weighted original spends.
    Unfrozen channels get ``original * (1 + per_change/100)`` plus their share
    of that gap. Spends are floored at 1e-3 in these calculations.

    Args:
        per_change: Percentage change applied to every channel's original spend.
        metrics_selected: Response-metric key.
        panel_selected: Panel key.
    """
    planner = st.session_state["project_dct"]["scenario_planner"]
    # Shallow copies, as before: nested channel dicts are still shared.
    modified_data = planner["modified_metadata_file"].copy()
    original_data = planner["original_metadata_file"].copy()

    modified_panel = modified_data[metrics_selected][panel_selected]
    original_panel = original_data[metrics_selected][panel_selected]

    frozen, unfrozen = [], []
    for name in modified_panel["channels"]:
        widget_key = f"{metrics_selected}_{panel_selected}_{name}_allow_optimize_key"
        (frozen if st.session_state.get(widget_key, False) else unfrozen).append(name)

    growth = (100 + per_change) / 100

    # Spend gap left by frozen channels, plus their original weighted spends.
    frozen_share = 0
    frozen_spends = 0
    for name in frozen:
        source = original_panel["channels"][name]
        rate = source["conversion_rate"]
        base = source["actual_total_spends"]
        current = modified_panel["channels"][name]["modified_total_spends"]
        gap = max(base, 1e-3) * growth - max(current, 1e-3)
        frozen_share += gap * rate
        frozen_spends += max(base, 1e-3) * rate

    # Give the gap to unfrozen channels by their share of the unfrozen original spend.
    for name in unfrozen:
        source = original_panel["channels"][name]
        rate = source["conversion_rate"]
        base = source["actual_total_spends"]
        weight = (base * rate) / (original_panel["actual_total_spends"] - frozen_spends)
        modified_panel["channels"][name]["modified_total_spends"] = (
            max(base, 1e-3) * growth + (frozen_share * weight) / rate
        )

    planner["modified_metadata_file"] = modified_data
# Function to update total_absolute_main_key change
def total_absolute_main_key_change(metrics_selected, panel_selected, optimization_goal):
    """Widget callback for the ``total_absolute_main_key`` text input.

    The typed value (optionally K/M/B/T-suffixed) is validated and scaled by
    ``st.session_state["multiplier"]``. It is compared with the panel's actual
    total: ``actual_total_spends`` for the ``"Spend"`` goal,
    ``actual_total_sales`` otherwise. If it falls outside [50%, 150%] of that
    total, it is reset to the actual total and a warning is stored. The result
    is saved as ``modified_total_spends`` or ``modified_total_sales``. For the
    spend goal, the implied percentage change is then spread across channels
    by ``total_input_change``.
    """
    typed_value = st.session_state["total_absolute_main_key"]
    if not is_valid_number_format(typed_value):
        return
    new_absolute = convert_to_float(typed_value) * st.session_state["multiplier"]

    scenario_file = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    panel_data = scenario_file[metrics_selected][panel_selected]
    spend_goal = optimization_goal == "Spend"
    actual_field, modified_field = (
        ("actual_total_spends", "modified_total_spends")
        if spend_goal
        else ("actual_total_sales", "modified_total_sales")
    )
    old_absolute = panel_data[actual_field]

    if new_absolute < old_absolute * 0.5 or new_absolute > old_absolute * 1.5:
        new_absolute = old_absolute
        st.session_state.message_display = {
            "type": "warning",
            "message": "Keep total spending within ±50% of the original value.",
            "icon": "⚠️",
        }

    panel_data[modified_field] = new_absolute
    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = scenario_file

    if spend_goal:
        per_change = ((new_absolute - old_absolute) / old_absolute) * 100
        total_input_change(per_change, metrics_selected, panel_selected)
# Function to update total_absolute_key change
def total_absolute_key_change(metrics_selected, panel_selected, optimization_goal):
    """Widget callback for the ``total_absolute_key`` text input.

    The typed value (optionally K/M/B/T-suffixed) is scaled by
    ``st.session_state["multiplier"]``. It is stored as the panel's
    ``modified_total_spends`` for goal ``"Spend"``, or as
    ``modified_total_sales`` for any other goal. For the spend goal, the
    percentage change versus ``actual_total_spends`` is then spread across
    channels with ``total_input_change``.

    Invalid input is rejected by ``is_valid_number_format``, which records a
    warning in ``message_display``, and the scenario is left unchanged.
    """
    # convert_to_float raises on malformed text ("", "abc", "5x"), which used to
    # crash this callback. Validate first, as total_absolute_main_key_change does.
    if not is_valid_number_format(st.session_state["total_absolute_key"]):
        return

    new_absolute = (
        convert_to_float(st.session_state["total_absolute_key"])
        * st.session_state["multiplier"]
    )

    data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
    if optimization_goal == "Spend":
        data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute
        old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"]
    else:
        data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute
        old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"]

    st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data

    # Only a spend target is redistributed over channels.
    if optimization_goal == "Spend":
        per_change = ((new_absolute - old_absolute) / old_absolute) * 100
        total_input_change(per_change, metrics_selected, panel_selected)
# Function to update total_percentage_key change
def total_percentage_key_change(
    metrics_selected,
    panel_selected,
    absolute_value,
    optimization_goal,
):
    """Widget callback for the ``total_percentage_key`` number input.

    ``absolute_value`` scaled by ``1 + total_percentage_key / 100`` becomes the
    panel's ``modified_total_spends`` (goal ``"Spend"``) or
    ``modified_total_sales`` (any other goal). For the spend goal, the
    percentage change versus ``actual_total_spends`` is then spread across
    channels with ``total_input_change``.
    """
    new_absolute = absolute_value * (1 + st.session_state["total_percentage_key"] / 100)

    scenario_file = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    panel_data = scenario_file[metrics_selected][panel_selected]
    spend_goal = optimization_goal == "Spend"
    modified_field, actual_field = (
        ("modified_total_spends", "actual_total_spends")
        if spend_goal
        else ("modified_total_sales", "actual_total_sales")
    )
    panel_data[modified_field] = new_absolute
    old_absolute = panel_data[actual_field]

    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = scenario_file

    if not spend_goal:
        return
    per_change = ((new_absolute - old_absolute) / old_absolute) * 100
    total_input_change(per_change, metrics_selected, panel_selected)
# Function to update bound change
def bound_change(metrics_selected, panel_selected, channel_key, channel):
    """Widget callback: store a channel's percentage-change optimization bounds.

    Reads ``<channel_key>_lower_key`` and ``<channel_key>_upper_key``. If the
    lower bound exceeds the upper bound, the bounds fall back to ``[-10, 10]``
    and a warning is stored in ``message_display``.
    """
    lower = st.session_state[f"{channel_key}_lower_key"]
    upper = st.session_state[f"{channel_key}_upper_key"]

    inverted = lower > upper
    new_bounds = [-10, 10] if inverted else [lower, upper]
    if inverted:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Lower bound cannot be greater than Upper bound.",
            "icon": "⚠️",
        }

    planner = st.session_state["project_dct"]["scenario_planner"]
    scenario_file = planner["modified_metadata_file"]
    scenario_file[metrics_selected][panel_selected]["channels"][channel][
        "bounds"
    ] = new_bounds
    planner["modified_metadata_file"] = scenario_file
# Function to update freeze change
def freeze_change(metrics_selected, panel_selected, channel_key, channel, channel_list):
    """Widget callback for a channel's freeze checkbox (``<key>_allow_optimize_key``).

    First counts the channels in ``channel_list`` whose checkbox is falsy,
    i.e. not frozen. With fewer than two, a warning is stored and nothing
    else changes. Otherwise the channel is updated: frozen channels get
    bounds ``[0, 0]`` and ``freeze=True``, unfrozen ones get ``[-10, 10]`` and
    ``freeze=False``.

    NOTE(review): on that early return the checkbox value is not reset here,
    so the widget and the stored ``freeze`` flag may disagree. Confirm this is
    intended.
    """
    unfrozen_channel_count = sum(
        not st.session_state[
            f"{metrics_selected}_{panel_selected}_{name}_allow_optimize_key"
        ]
        for name in channel_list
    )
    if unfrozen_channel_count < 2:
        st.session_state.message_display = {
            "type": "warning",
            "message": "Please allow at least two channels to be optimized.",
            "icon": "⚠️",
        }
        return

    is_frozen = bool(st.session_state[f"{channel_key}_allow_optimize_key"])
    new_bounds = [0, 0] if is_frozen else [-10, 10]

    planner = st.session_state["project_dct"]["scenario_planner"]
    scenario_file = planner["modified_metadata_file"]
    channel_entry = scenario_file[metrics_selected][panel_selected]["channels"][channel]
    channel_entry["bounds"] = new_bounds
    channel_entry["freeze"] = is_frozen
    planner["modified_metadata_file"] = scenario_file
# Function to calculate y, ROI and MROI for given point
def get_point_parms(
    x_val,
    current_s_curve_params,
    current_channel_proportion,
    current_conversion_rate,
    channel_correction,
):
    """Evaluate response, ROI and marginal ROI at a single spend point.

    Args:
        x_val: Spend in the channel's native units. It is multiplied by
            ``current_conversion_rate`` to get the converted spend used for
            ROI/MROI.
        current_s_curve_params: S-curve parameters (``power``, ``K``, ``b``,
            ``a``, ``x0``).
        current_channel_proportion: Factor(s) multiplied into the spend before
            evaluating the curve. The curve output is summed.
        current_conversion_rate: Conversion rate applied to native spend.
        channel_correction: Additive correction (already summed) for the channel.

    Returns:
        ``(roi, mroi, y)``.
        - ``y`` is the summed curve output plus the correction.
        - ``roi`` is ``y`` divided by the converted spend, reported as 0 when
          that spend is 0. This matches the ROI convention used for the
          plotted curve points; previously the result was inf/nan or a
          ZeroDivisionError.
        - ``mroi`` is a forward difference with a 1e-3 step in converted
          spend units.
    """

    def response(native_spend):
        # Summed S-curve output for a native-unit spend, plus the correction.
        return (
            sum(
                s_curve(
                    native_spend * current_channel_proportion,
                    current_s_curve_params["power"],
                    current_s_curve_params["K"],
                    current_s_curve_params["b"],
                    current_s_curve_params["a"],
                    current_s_curve_params["x0"],
                )
            )
            + channel_correction
        )

    y_val = response(x_val)

    # MROI: nudge the converted spend and map it back to native units.
    nudge = 1e-3
    x1 = float(x_val * current_conversion_rate)
    y1 = float(y_val)
    x2 = x1 + nudge
    y2 = response(x2 / current_conversion_rate)
    mroi_val = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0

    converted_spend = x_val * current_conversion_rate
    roi_val = y_val / converted_spend if converted_spend != 0 else 0
    return roi_val, mroi_val, y_val
# Function to find segment value
def find_segment_value(x, roi, mroi, roi_threshold=1, mroi_threshold=0.05):
    """Split a response curve's x-range into yellow / green / red segments.

    Args:
        x: Indexable spend values of the curve points.
        roi: NumPy array of ROI per point.
        mroi: NumPy array of marginal ROI per point.
        roi_threshold: ROI must exceed this.
        mroi_threshold: MROI must also exceed this for the green segment.

    Returns:
        ``(start, end, left, right)``.
        - ``start`` and ``end`` are the first and last x values.
        - ``left`` is the first x where ROI exceeds the threshold.
        - ``right`` is the last x where both ROI and MROI exceed their
          thresholds.
        - If no such point exists, the value falls back to ``x[0]``.
        - ``left`` is clamped so it never exceeds ``right``.
    """
    above_roi = roi > roi_threshold
    green_mask = above_roi & (mroi > mroi_threshold)
    roi_hits = np.nonzero(above_roi)[0]
    green_hits = np.nonzero(green_mask)[0]

    left_value = x[roi_hits[0]] if len(roi_hits) else x[0]
    right_value = x[green_hits[-1]] if len(green_hits) else x[0]
    if left_value > right_value:
        left_value = right_value
    return x[0], x[-1], left_value, right_value
# Function to generate response curves plots
@st.cache_data(show_spinner=False)
def generate_response_curve_plots(
    channel_list,
    s_curve_params,
    channels_proportion,
    original_scenario_data,
    multiplier,
    roi_threshold=None,
):
    """Build one Plotly response-curve figure per channel.

    Each curve spans 0 to 5x the channel's actual spend in 100 points. The
    background is shaded in three bands:
    - yellow: up to the first point where ROI exceeds the threshold;
    - green: up to the last point where ROI exceeds it AND the normalized
      curve slope exceeds ``mroi_threshold``;
    - red: beyond that.

    Args:
        channel_list: Channels to plot, in display order.
        s_curve_params: Mapping channel -> S-curve parameters.
        channels_proportion: Mapping channel -> proportion factor(s).
        original_scenario_data: Scenario slice with per-channel
            ``actual_total_spends``, ``conversion_rate`` and ``correction``.
        multiplier: Divisor applied to plotted x/y values and hover labels.
        roi_threshold: ROI threshold for region shading. Defaults to
            ``st.session_state.roi_threshold``. ``st.cache_data`` keys on the
            arguments, so pass it explicitly. A value read from session state
            inside this function does not invalidate cached figures when the
            threshold changes.

    Returns:
        ``(figures, channel_roi_mroi, region_start_end)``:
        - ``figures``: the figures in ``channel_list`` order;
        - ``channel_roi_mroi``: channel -> ``{"actual_roi", "actual_mroi"}``;
        - ``region_start_end``: channel -> ``{"start_value", "end_value",
          "left_value", "right_value"}``.

    NOTE(review): hover labels use the module-level ``metrics_selected_formatted``,
    which is not part of the cache key either.
    """
    if roi_threshold is None:
        roi_threshold = st.session_state.roi_threshold

    figures, channel_roi_mroi, region_start_end = [], {}, {}
    for channel in channel_list:
        spends_actual = original_scenario_data["channels"][channel][
            "actual_total_spends"
        ]
        conversion_rate = original_scenario_data["channels"][channel]["conversion_rate"]
        # Summed once per channel; it was previously recomputed a second time.
        channel_correction = sum(
            original_scenario_data["channels"][channel]["correction"]
        )
        # Native-unit spends from 0 to 5x actual, and the same points converted.
        x_actual = np.linspace(0, 5 * spends_actual, 100)
        x_plot = x_actual * conversion_rate
        # Response (summed curve output plus correction) at each point.
        y_plot = [
            sum(
                s_curve(
                    (x * channels_proportion[channel]),
                    s_curve_params[channel]["power"],
                    s_curve_params[channel]["K"],
                    s_curve_params[channel]["b"],
                    s_curve_params[channel]["a"],
                    s_curve_params[channel]["x0"],
                )
            )
            + channel_correction
            for x in x_actual
        ]
        # ROI per point; 0 at zero spend to avoid dividing by zero.
        roi = [float(y) / float(x) if x != 0 else 0 for x, y in zip(x_plot, y_plot)]
        # MROI per point: forward difference with a 1e-3 step in converted units.
        nudge = 1e-3
        mroi = []
        for x_point, y_point in zip(x_plot, y_plot):
            x1 = float(x_point)
            y1 = float(y_point)
            x2 = x1 + nudge
            y2 = (
                sum(
                    s_curve(
                        ((x2 / conversion_rate) * channels_proportion[channel]),
                        s_curve_params[channel]["power"],
                        s_curve_params[channel]["K"],
                        s_curve_params[channel]["b"],
                        s_curve_params[channel]["a"],
                        s_curve_params[channel]["x0"],
                    )
                )
                + channel_correction
            )
            mroi_value = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0
            mroi.append(mroi_value)
        # y, ROI and MROI at the actual spend point
        roi_actual, mroi_actual, y_actual = get_point_parms(
            spends_actual,
            s_curve_params[channel],
            channels_proportion[channel],
            conversion_rate,
            channel_correction,
        )
        fig = go.Figure()
        # S-curve line
        fig.add_trace(
            go.Scatter(
                x=np.array(x_plot) / multiplier,
                y=np.array(y_plot) / multiplier,
                mode="lines",
                name="Metrics",
                hoverinfo="text",
                text=[
                    f"Spends: {numerize(x / multiplier)}<br>{metrics_selected_formatted}: {numerize(y / multiplier)}<br>ROI: {r:.2f}<br>MROI: {m:.2f}"
                    for x, y, r, m in zip(x_plot, y_plot, roi, mroi)
                ],
            )
        )
        # Actual spend marker
        fig.add_trace(
            go.Scatter(
                x=[spends_actual * conversion_rate / multiplier],
                y=[y_actual / multiplier],
                mode="markers",
                marker=dict(color="cyan", size=10, symbol="circle"),
                name="Actual Spend",
                hoverinfo="text",
                text=[
                    f"Actual Spend: {numerize(spends_actual * conversion_rate / multiplier)}<br>{metrics_selected_formatted}: {numerize(y_actual / multiplier)}<br>ROI: {roi_actual:.2f}<br>MROI: {mroi_actual:.2f}"
                ],
                showlegend=True,
            )
        )
        # Normalize both axes to [.., 1] and take the slope between consecutive
        # points; this scale-free slope is compared against mroi_threshold.
        x, y = np.array(x_plot), np.array(y_plot)
        x_scaled, y_scaled = x / max(x), y / max(y)
        mroi_scaled = np.zeros_like(x_scaled)
        for j in range(1, len(x_scaled)):
            x1, y1 = x_scaled[j - 1], y_scaled[j - 1]
            x2, y2 = x_scaled[j], y_scaled[j]
            mroi_scaled[j] = (y2 - y1) / (x2 - x1) if (x2 - x1) != 0 else 0
        start_value, end_value, left_value, right_value = find_segment_value(
            x_plot, np.array(roi), mroi_scaled, roi_threshold, mroi_threshold
        )
        region_start_end[channel] = {
            "start_value": start_value,
            "end_value": end_value,
            "left_value": left_value,
            "right_value": right_value,
        }
        y_max = max(y_plot) * 1.3  # 30% headroom above the curve maximum
        # Yellow region
        fig.add_shape(
            type="rect",
            x0=start_value / multiplier,
            y0=0,
            x1=left_value / multiplier,
            y1=y_max / multiplier,
            line=dict(width=0),
            fillcolor="rgba(255, 255, 0, 0.3)",
            layer="below",
        )
        # Green region
        fig.add_shape(
            type="rect",
            x0=left_value / multiplier,
            y0=0,
            x1=right_value / multiplier,
            y1=y_max / multiplier,
            line=dict(width=0),
            fillcolor="rgba(0, 255, 0, 0.3)",
            layer="below",
        )
        # Red region
        fig.add_shape(
            type="rect",
            x0=right_value / multiplier,
            y0=0,
            x1=end_value / multiplier,
            y1=y_max / multiplier,
            line=dict(width=0),
            fillcolor="rgba(255, 0, 0, 0.3)",
            layer="below",
        )
        fig.update_layout(
            title=f"{name_formating(channel)}",
            showlegend=False,
            xaxis=dict(
                showgrid=True,
                showticklabels=True,
                tickformat=".2s",
                gridcolor="lightgrey",
                gridwidth=0.5,
                griddash="dot",
            ),
            yaxis=dict(
                showgrid=True,
                showticklabels=True,
                tickformat=".2s",
                gridcolor="lightgrey",
                gridwidth=0.5,
                griddash="dot",
            ),
            template="plotly_white",
            margin=dict(l=20, r=20, t=30, b=20),
            height=100 * (len(channel_list) + 4 - 1) // 4,
        )
        figures.append(fig)
        channel_roi_mroi[channel] = {
            "actual_roi": roi_actual,
            "actual_mroi": mroi_actual,
        }
    return figures, channel_roi_mroi, region_start_end
# Function to add modified spends/metrics point on plot
def modified_metrics_point(
    fig,
    modified_spends,
    s_curve_params,
    channels_proportion,
    conversion_rate,
    channel_correction,
):
    """Overlay the modified-spend marker on ``fig`` and report its ROI/MROI.

    ``get_point_parms`` supplies ``(roi, mroi, y)`` for ``modified_spends``.
    The marker is placed at ``modified_spends * conversion_rate`` on the x axis
    and ``y`` on the y axis. Both coordinates are divided by
    ``st.session_state["multiplier"]``.

    Returns:
        tuple: ``(roi_modified, mroi_modified, fig)`` where ``fig`` is the same
        figure object with the new trace added.
    """
    roi_modified, mroi_modified, y_modified = get_point_parms(
        modified_spends,
        s_curve_params,
        channels_proportion,
        conversion_rate,
        channel_correction,
    )

    # Timeframe scaling applied to both axes and to the hover text
    scale = st.session_state["multiplier"]
    spend_on_axis = modified_spends * conversion_rate / scale
    metric_on_axis = y_modified / scale

    # The hover label says "Modified Spend", while the legend entry says "Optimized Spend"
    hover_label = (
        f"Modified Spend: {numerize(spend_on_axis)}<br>"
        f"{metrics_selected_formatted}: {numerize(metric_on_axis)}<br>"
        f"ROI: {roi_modified:.2f}<br>"
        f"MROI: {mroi_modified:.2f}"
    )

    marker_trace = go.Scatter(
        x=[spend_on_axis],
        y=[metric_on_axis],
        mode="markers",
        marker=dict(color="blueviolet", size=10, symbol="circle"),
        name="Optimized Spend",
        hoverinfo="text",
        text=[hover_label],
        showlegend=True,
    )
    fig.add_trace(marker_trace)

    return roi_modified, mroi_modified, fig
# Function to update bound type change
def bound_type_change():
    """Callback for the "Bounds" toggle (session key ``bound_type_key``).

    Writes the toggle value into the modified scenario metadata for the
    currently selected metric/panel. When the toggle is switched off, every
    channel of that panel has its bounds reset to ``[-10, 10]``.
    """
    toggled_on = st.session_state["bound_type_key"]

    metadata = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    panel_metadata = metadata[metrics_selected][panel_selected]
    panel_metadata["bound_type"] = toggled_on

    channel_settings = panel_metadata["channels"]
    channel_names = list(channel_settings.keys())

    # Custom bounds disabled -> fall back to the default [-10, 10] for each channel
    if not toggled_on:
        for name in channel_names:
            # A fresh list per channel so channels never share a bounds object
            channel_settings[name]["bounds"] = [-10, 10]

    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = metadata
# Function to format a number as display text (4 decimals below 1, numerized otherwise)
def format_value(input_value):
    """Return ``input_value`` as a display string.

    Values with magnitude below 1 are rendered with four decimal places.
    Anything larger goes through ``numerize(input_value, 1)``.
    """
    if abs(input_value) < 1:
        return f"{input_value:.4f}"
    return f"{numerize(input_value, 1)}"
# Function to round a number (4 decimals below 1, 1 decimal otherwise)
def round_value(input_value):
    """Round ``input_value`` to 4 decimals if its magnitude is below 1, else to 1."""
    digits = 4 if abs(input_value) < 1 else 1
    return round(input_value, digits)
# Function to generate ROI and MROI plots for all channels
@st.cache_data(show_spinner=False)
def roi_mori_plot(channel_roi_mroi):
    """Build compact "actual vs optimized" bar charts of ROI and MROI per channel.

    Args:
        channel_roi_mroi: Mapping of channel name -> dict containing the keys
            ``actual_roi``, ``optimized_roi``, ``actual_mroi`` and
            ``optimized_mroi``.

    Returns:
        dict: channel name -> ``{"fig_roi": Figure, "fig_mroi": Figure}``.
    """

    def comparison_figure(label, actual_value, optimized_value):
        # Two touching bars (cyan = actual, blueviolet = optimized), each labelled
        # with format_value(), under a small ``label`` title annotation.
        figure = go.Figure()
        for prefix, value, colour in (
            ("Actual", actual_value, "cyan"),
            ("Optimized", optimized_value, "blueviolet"),
        ):
            bar_name = f"{prefix} {label}"
            figure.add_trace(
                go.Bar(
                    x=[bar_name],
                    y=[value],
                    name=bar_name,
                    marker_color=colour,
                    width=1,
                    text=[format_value(value)],
                    textposition="auto",
                    textfont=dict(color="black", size=14),
                )
            )
        figure.update_layout(
            annotations=[
                dict(
                    x=0.5,
                    y=1.3,
                    xref="paper",
                    yref="paper",
                    text=label,
                    showarrow=False,
                    font=dict(size=14),
                )
            ],
            barmode="group",
            bargap=0,
            showlegend=False,
            width=110,
            height=110,
            xaxis=dict(
                showticklabels=True,
                showgrid=False,
                tickangle=0,
                # Shorten the category names to just "Actual" / "Optimized" on the axis
                ticktext=["Actual", "Optimized"],
                tickvals=[f"Actual {label}", f"Optimized {label}"],
            ),
            yaxis=dict(showticklabels=False, showgrid=False),
            margin=dict(t=20, b=20, r=0, l=0),
        )
        return figure

    return {
        channel: {
            "fig_roi": comparison_figure(
                "ROI", metrics["actual_roi"], metrics["optimized_roi"]
            ),
            "fig_mroi": comparison_figure(
                "MROI", metrics["actual_mroi"], metrics["optimized_mroi"]
            ),
        }
        for channel, metrics in channel_roi_mroi.items()
    }
# Function to save the current scenario with the mentioned name
def save_scenario(
    scenario_dict,
    metrics_selected,
    panel_selected,
    optimization_goal,
    channel_roi_mroi,
    timeframe,
    multiplier,
):
    """Validate ``st.session_state["scenario_name"]`` and save the scenario under it.

    If any check fails, a warning is stored in
    ``st.session_state.message_display`` and the function returns early.
    On success:
    - the scenario is added to the project's ``saved_scenarios_dict``;
    - the whole ``project_dct`` is pickled and written through ``update_db``;
    - a success message and toast are shown;
    - the name input is cleared.

    Note: ``scenario_dict`` is mutated in place (selection details are added)
    before the duplicate-name check, so it is modified even when that check fails.

    Args:
        scenario_dict: Scenario payload to save; must be non-empty.
        metrics_selected: Stored under ``"metrics_selected"``.
        panel_selected: Stored under ``"panel_selected"``.
        optimization_goal: Stored under ``"optimization"``.
        channel_roi_mroi: Stored under ``"channel_roi_mroi"``.
        timeframe: Stored under ``"timeframe"``.
        multiplier: Stored under ``"multiplier"``.
    """
    state = st.session_state

    def _warn(text):
        # Queue a warning banner; each caller returns immediately afterwards
        state.message_display = {
            "type": "warning",
            "message": text,
            "icon": "⚠️",
        }

    # Drop leading/trailing whitespace from the typed name, when one exists
    if state["scenario_name"] is not None:
        state["scenario_name"] = state["scenario_name"].strip()
    name = state["scenario_name"]

    if name is None or name == "":
        _warn("Please provide a name to save the scenario.")
        return

    # Enforce the configured length limits and the [A-Za-z0-9_] character set
    length_ok = word_length_limit_lower <= len(name) <= word_length_limit_upper
    if not (length_ok and bool(re.match("^[A-Za-z0-9_]*$", name))):
        _warn(
            f"Please provide a valid scenario name ({word_length_limit_lower}-{word_length_limit_upper} characters, only A-Z, a-z, 0-9, and _)."
        )
        return

    if not scenario_dict:
        _warn("Nothing to save. The scenario data is empty.")
        return

    # Attach the selection context so the saved scenario is self-describing
    scenario_details = {
        "panel_selected": panel_selected,
        "metrics_selected": metrics_selected,
        "optimization": optimization_goal,
        "channel_roi_mroi": channel_roi_mroi,
        "timeframe": timeframe,
        "multiplier": multiplier,
    }
    for detail_key, detail_value in scenario_details.items():
        scenario_dict[detail_key] = detail_value

    saved_scenarios = state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"]
    if name in saved_scenarios:
        _warn(
            "Name already exists. Please change the name or delete the existing scenario from the Saved Scenario page."
        )
        return

    saved_scenarios[name] = scenario_dict
    state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"] = saved_scenarios

    # Persist the full project dictionary
    update_db(
        prj_id=state["project_number"],
        page_nam="Scenario Planner",
        file_nam="project_dct",
        pkl_obj=pickle.dumps(state["project_dct"]),
        schema=schema,
    )

    success_text = f"Scenario '{name}' has been successfully saved!"
    state.message_display = {
        "type": "success",
        "message": success_text,
        "icon": "💾",
    }
    st.toast(success_text, icon="💾")

    # Reset the name input for the next save
    state["scenario_name"] = ""
# Function to calculate the RGBA color code based on the spends value and region boundaries
def calculate_rgba(spends_value, region_start_end):
    """Return the RGBA badge colour for ``spends_value`` on a channel's response curve.

    ``region_start_end`` holds the keys ``start_value``, ``left_value``,
    ``right_value`` and ``end_value``. These split the spend axis into the same
    yellow / green / red regions that are shaded on the response-curve plots:

    * ``[start_value, left_value]`` -> yellow ``(255, 255, 0)``
    * ``(left_value, right_value]`` -> green  ``(0, 128, 0)``
    * ``(right_value, end_value]``  -> red    ``(255, 0, 0)``

    Alpha is interpolated linearly between 0.1 and 0.4 according to where the
    value sits inside its region (see the per-branch comments).

    Returns:
        tuple | None: ``(r, g, b, alpha)``, or ``None`` when ``spends_value``
        is outside ``[start_value, end_value]`` (no branch matches).
    """
    start_value = region_start_end["start_value"]
    end_value = region_start_end["end_value"]
    left_value = region_start_end["left_value"]
    right_value = region_start_end["right_value"]

    # Linear interpolation: min_alpha at ``start``, max_alpha at ``end``
    def calculate_alpha(position, start, end, min_alpha=0.1, max_alpha=0.4):
        # A zero-width region (e.g. start_value == left_value) would divide by
        # zero. The only position inside it is that single point, so return the
        # value the formula approaches at ``end``.
        if end == start:
            return max_alpha
        return min_alpha + (max_alpha - min_alpha) * (position - start) / (end - start)

    if start_value <= spends_value <= left_value:
        # Yellow range (255, 255, 0) - More transparent towards left, darker towards start
        alpha = calculate_alpha(spends_value, left_value, start_value)
        return (255, 255, 0, alpha)  # RGB for yellow
    elif left_value < spends_value <= right_value:
        # Green range (0, 128, 0) - More transparent towards right, darker towards left
        alpha = calculate_alpha(spends_value, right_value, left_value)
        return (0, 128, 0, alpha)  # RGB for green
    elif right_value < spends_value <= end_value:
        # Red range (255, 0, 0) - More transparent towards right, darker towards end
        alpha = calculate_alpha(spends_value, right_value, end_value)
        return (255, 0, 0, alpha)  # RGB for red
    # Outside [start_value, end_value]: no colour applies
    return None
# Function to format and display the channel name with a color and background color
def display_channel_name_with_background_color(
    channel_name, background_color=(0, 128, 0, 0.1)
):
    """Return an HTML snippet showing the formatted channel name on a tinted badge.

    Args:
        channel_name: Raw channel identifier; prettified with ``name_formating``.
        background_color: ``(r, g, b, a)`` tuple interpolated into the CSS
            ``rgba()`` background. Defaults to ``(0, 128, 0, 0.1)``.

    Returns:
        str: HTML markup; on this page it is rendered via ``st.markdown`` with
        ``unsafe_allow_html=True``.
    """
    formatted_name = name_formating(channel_name)
    # Unpack the RGBA values (expects exactly four components)
    r, g, b, a = background_color
    # Create the HTML content with specified background color.
    # NOTE(review): formatted_name is interpolated without HTML escaping; if
    # channel names can contain <, > or &, consider html.escape — confirm.
    html_content = f"""
<div style="
background-color: rgba({r}, {g}, {b}, {a});
padding: 10px;
display: inline-block;
border-radius: 5px;">
<strong>{formatted_name}</strong>
</div>
"""
    return html_content
# Function to check optimization success
def check_optimization_success(
    channel_list,
    input_channels_spends,
    output_channels_spends,
    bounds_dict,
    optimization_goal,
    modified_total_metrics,
    actual_total_metrics,
    modified_total_spends,
    actual_total_spends,
    original_total_spends,
    optimization_success,
):
    """Validate an optimizer result and return ``(ok, message, icon)``.

    Checks run in order, and the first failure is returned:

    1. Each channel's optimized spend (absolute value) lies within its
       percentage bounds widened by 1 percentage point on each side. The lower
       limit is floored at 0 and the upper limit at 10.
    2. The fixed quantity (total spend when ``optimization_goal == "Spend"``,
       total metric otherwise) moved by at most 1%.
    3. The modified total spend stays within 50%-150% of
       ``original_total_spends``.
    4. The optimizer reported convergence.
    """
    failure_icon = "❌"

    for channel in channel_list:
        baseline_spend = input_channels_spends[channel]
        optimized_spend = abs(output_channels_spends[channel])
        lower_pct = bounds_dict[channel][0]
        upper_pct = bounds_dict[channel][1]

        # The extra -1 / +1 percentage points give a 1% tolerance
        floor_spend = max(baseline_spend * (100 + lower_pct - 1) / 100, 0)
        ceiling_spend = max(baseline_spend * (100 + upper_pct + 1) / 100, 10)

        if optimized_spend > ceiling_spend or optimized_spend < floor_spend:
            return (
                False,
                "Optimization failed: strict bounds. Use flexible bounds.",
                failure_icon,
            )

    # The quantity the user fixed must be preserved within 1% tolerance
    if optimization_goal == "Spend":
        relative_drift = abs(
            (modified_total_spends - actual_total_spends) / actual_total_spends
        )
        drift_message = "Optimization failed: input and optimized spends differ. Use flexible bounds."
    else:
        relative_drift = abs(
            (modified_total_metrics - actual_total_metrics) / actual_total_metrics
        )
        drift_message = "Optimization failed: input and optimized metrics differ. Use flexible bounds."
    if relative_drift > 0.01:
        return False, drift_message, failure_icon

    # New total spend may deviate at most ±50% from the original total
    min_total = original_total_spends * 0.5
    max_total = original_total_spends * 1.5
    if modified_total_spends < min_total or modified_total_spends > max_total:
        return (
            False,
            "New spends optimized are outside the allowed range of ±50%.",
            failure_icon,
        )

    if not optimization_success:
        return False, "Optimization failed to converge.", failure_icon

    return True, "Optimization successful.", "💸"
# Function to check if the optimization target is achievable within the given bounds
def check_target_achievability(
    optimize_allow,
    optimization_goal,
    lower_achievable_target,
    upper_achievable_target,
    total_absolute_target,
):
    """Decide whether optimization may run for the requested target.

    If ``total_absolute_target`` lies outside ``[lower_achievable_target,
    upper_achievable_target]``, an error describing the achievable range is
    stored in ``st.session_state.message_display`` and ``False`` is returned.
    Otherwise, a pending message starting with "Achievable" is cleared and
    ``True`` is returned. If no such message is pending, ``optimize_allow``
    is returned unchanged.
    """
    scale = st.session_state.multiplier
    given_input = "response metric" if optimization_goal == "Spend" else "spends"
    range_message = (
        f"Achievable {optimization_goal} with the given {given_input} and bounds ranges from "
        f"{numerize(lower_achievable_target / scale)} to "
        f"{numerize(upper_achievable_target / scale)}"
    )

    # Target outside the reachable interval -> block optimization
    if (lower_achievable_target > total_absolute_target) or (
        upper_achievable_target < total_absolute_target
    ):
        st.session_state.message_display = {
            "type": "error",
            "message": range_message,
            "icon": "⚠️",
        }
        return False

    # Target reachable again -> drop a stale "Achievable ..." error, if any
    pending_message = st.session_state.message_display["message"]
    if pending_message is not None and str(pending_message).startswith("Achievable"):
        st.session_state.message_display = {
            "type": "success",
            "message": None,
            "icon": "",
        }
        return True

    return optimize_allow
# Function to display a message with the appropriate type and icon
def display_message():
    """Render and log the message stored in ``st.session_state.message_display``.

    Nothing happens when the stored message is ``None``. The banner widget is
    chosen by the ``type`` field:
    - "success" -> ``st.success``, logged at "info";
    - "warning" -> ``st.warning``, logged at "warning";
    - "error" -> ``st.error``, logged at "error";
    - anything else -> ``st.info``, logged at "info".
    """
    details = st.session_state.message_display
    message_type = details["type"]
    message = details["message"]
    icon = details["icon"]

    if message is None:
        return

    # message type -> (streamlit banner function, log level)
    renderers = {
        "success": (st.success, "info"),
        "warning": (st.warning, "warning"),
        "error": (st.error, "error"),
    }
    render, log_level = renderers.get(message_type, (st.info, "info"))

    render(message, icon=icon)
    log_message(log_level, message, "Scenario Planner")
# Function to change bounds for all channels
def all_bound_change(channel_list, apply_all=False):
    """Apply the "All Lower/Upper Bounds" inputs to the selected panel.

    Reads ``all_lower_key`` / ``all_upper_key`` from session state. When
    lower < upper, ``[lower, upper]`` is stored as the panel-level bounds.
    If ``apply_all`` is True, it is also stored as the bounds of every
    non-frozen channel in ``channel_list``. Otherwise a warning is queued in
    ``st.session_state.message_display`` and nothing is changed.
    """
    lower = st.session_state["all_lower_key"]
    upper = st.session_state["all_upper_key"]

    # Equal values are rejected as well, even though the message only mentions "greater than"
    if not (lower < upper):
        st.session_state.message_display = {
            "type": "warning",
            "message": "Lower bound cannot be greater than Upper bound.",
            "icon": "⚠️",
        }
        return

    metadata = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    panel_metadata = metadata[metrics_selected][panel_selected]

    if apply_all:
        for channel in channel_list:
            channel_settings = panel_metadata["channels"][channel]
            # Frozen channels keep their own bounds
            if not channel_settings["freeze"]:
                channel_settings["bounds"] = [lower, upper]

    # Panel-level holder that backs the "All" inputs
    panel_metadata["bounds"] = [lower, upper]

    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = metadata
try:
# Page Title
st.title("Scenario Planner")
# Retrieve the list of all metric names from the specified directory
metrics_list = get_metrics_names()
# Check if there are any metrics available in the metrics list
if not metrics_list:
# Display a warning message to the user if no metrics are found
st.warning(
"Please tune at least one model to generate response curves data.",
icon="⚠️",
)
# Log message
log_message(
"warning",
"Please tune at least one model to generate response curves data.",
"Scenario Planner",
)
st.stop()
# Widget columns
metric_col, panel_col, timeframe_col, save_progress_col = st.columns(4)
# Metrics Selection
metrics_selected = metric_col.selectbox(
"Response Metrics",
sorted(metrics_list),
format_func=name_formating,
key="response_metrics_selectbox_sp",
index=0,
)
metrics_selected_formatted = name_formating(metrics_selected)
# Retrieve the list of all panel names for specified Metrics
panel_list = get_panels_names(metrics_selected)
# Panel Selection
panel_selected = panel_col.selectbox(
"Panel",
sorted(panel_list),
format_func=name_formating,
key="panel_selected_selectbox_sp",
index=0,
)
panel_selected_formatted = name_formating(panel_selected)
# Timeframe Selection
timeframe_selected = timeframe_col.selectbox(
"Timeframe",
["Input Data Range", "Yearly", "Quarterly", "Monthly"],
key="timeframe_selected_selectbox_sp",
index=0,
)
# Check if the RCS metadata file does not exist
if (
st.session_state["project_dct"]["response_curves"]["original_metadata_file"]
is None
or st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
is None
):
# RCS metadata file does not exist. Generating new RCS data
generate_rcs_data()
# Log message
log_message(
"info",
"RCS metadata file does not exist. Generating new RCS data.",
"Scenario Planner",
)
# Load rcs metadata files if they exist
original_rcs_data, modified_rcs_data = load_rcs_metadata_files()
# Check if the scenario metadata file does not exist
if (
st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"]
is None
or st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
is None
):
# Scenario file does not exist. Generating new senario file data
generate_scenario_data()
# Load scenario metadata files if they exist
original_data, modified_data = load_scenario_metadata_files()
try:
# Data date range
date_range = pd.to_datetime(
list(original_data[metrics_selected][panel_selected]["channels"].values())[
0
]["dates"]
)
# Calculate the number of days between max and min dates
date_diff = pd.Series(date_range).diff()
day_data = int(
(date_range.max() - date_range.min()).days
+ (6 if date_diff.value_counts().idxmax() == pd.Timedelta(weeks=1) else 0)
)
# Set the multiplier based on the selected timeframe
if timeframe_selected == "Input Data Range":
st.session_state["multiplier"] = 1
elif timeframe_selected == "Yearly":
st.session_state["multiplier"] = day_data / 365
elif timeframe_selected == "Quarterly":
st.session_state["multiplier"] = day_data / 90
elif timeframe_selected == "Monthly":
st.session_state["multiplier"] = day_data / 30
except:
st.session_state["multiplier"] = 1
# Extract original scenario data for the selected metric and panel
original_scenario_data = original_data[metrics_selected][panel_selected]
# Extract modified scenario data for the same metric and panel
modified_scenario_data = modified_data[metrics_selected][panel_selected]
# Display Actual Vs Optimized
st.divider()
(
actual_spends_col,
actual_metrics_col,
actual_CPA_col,
base_col,
optimized_spends_col,
optimized_metrics_col,
optimized_CPA_col,
) = st.columns([1, 1, 1, 1, 1.5, 1.5, 1.5])
# Base Contribution
base_contribution = (
sum(original_scenario_data["constant"]) / st.session_state["multiplier"]
)
# Display Base Metric
base_col.metric(
f"Base {metrics_selected_formatted}",
numerize(base_contribution),
)
# Extracting and formatting values
actual_spends = numerize(
original_scenario_data["actual_total_spends"] / st.session_state["multiplier"]
)
actual_metric_value = numerize(
original_scenario_data["actual_total_sales"] / st.session_state["multiplier"]
)
optimized_spends = numerize(
modified_scenario_data["modified_total_spends"] / st.session_state["multiplier"]
)
optimized_metric_value = numerize(
modified_scenario_data["modified_total_sales"] / st.session_state["multiplier"]
)
# Calculate the deltas (differences) for spends and metrics
spends_delta_value = (
modified_scenario_data["modified_total_spends"]
- original_scenario_data["actual_total_spends"]
) / st.session_state["multiplier"]
metrics_delta_value = (
modified_scenario_data["modified_total_sales"]
- original_scenario_data["actual_total_sales"]
) / st.session_state["multiplier"]
# Calculate the percentage changes for spends and metrics
spends_percentage_change = (
spends_delta_value
/ (
original_scenario_data["actual_total_spends"]
/ st.session_state["multiplier"]
)
) * 100
metrics_percentage_change_media = (
metrics_delta_value
/ (
(
original_scenario_data["actual_total_sales"]
/ st.session_state["multiplier"]
)
- base_contribution
)
) * 100
metrics_percentage_change_all = (
metrics_delta_value
/ (
original_scenario_data["actual_total_sales"]
/ st.session_state["multiplier"]
)
) * 100
# Format the percentage change for display
spends_percentage_display = (
f"({round(spends_percentage_change, 1)}%)"
if abs(spends_percentage_change) >= 0.1
else "(0%)"
)
metrics_percentage_display_media = (
f"({round(metrics_percentage_change_media, 1)}%)"
if abs(metrics_percentage_change_media) >= 0.1
else "(0%)"
)
metrics_percentage_display_all = (
f"({round(metrics_percentage_change_all, 1)}%)"
if abs(metrics_percentage_change_all) >= 0.1
else "(0%)"
)
# Check if the delta for spends is less than 0.1% in absolute terms
if abs(spends_delta_value) < 0.001 * original_scenario_data["actual_total_spends"]:
spends_delta = "0"
else:
spends_delta = numerize(spends_delta_value)
# Check if the delta for metrics is less than 0.1% in absolute terms
if abs(metrics_delta_value) < 0.001 * original_scenario_data["actual_total_sales"]:
metrics_delta = "0"
else:
metrics_delta = numerize(metrics_delta_value)
# Display current and optimized CPA
actual_CPA = (
original_scenario_data["actual_total_spends"]
/ original_scenario_data["actual_total_sales"]
)
optimized_CPA = (
modified_scenario_data["modified_total_spends"]
/ modified_scenario_data["modified_total_sales"]
)
CPA_delta_value = optimized_CPA - actual_CPA
# Calculate the percentage change for CPA
CPA_percentage_change = (
((CPA_delta_value / actual_CPA) * 100) if actual_CPA != 0 else 0
)
CPA_percentage_display = (
f"({round(CPA_percentage_change, 1)}%)"
if abs(CPA_percentage_change) >= 0.1
else "(0%)"
)
# Check if the CPA delta is less than 0.1% in absolute terms
if abs(CPA_delta_value) < 0.001 * actual_CPA:
CPA_delta = "0"
else:
CPA_delta = round_value(CPA_delta_value)
# Display the metrics with percentage changes
actual_CPA_col.metric(
"Actual CPA",
(numerize(actual_CPA) if actual_CPA >= 1000 else round_value(actual_CPA)),
)
optimized_spends_col.metric(
"Optimized Spend",
f"{optimized_spends} {spends_percentage_display}",
delta=spends_delta,
)
optimized_metrics_col.metric(
f"Optimized {metrics_selected_formatted}",
f"{optimized_metric_value} {metrics_percentage_display_all}",
delta=f"{metrics_delta} {metrics_percentage_display_media}",
)
optimized_CPA_col.metric(
"Optimized CPA",
(
f"{numerize(optimized_CPA) if optimized_CPA >= 1000 else round_value(optimized_CPA)} {CPA_percentage_display}"
),
delta=CPA_delta,
delta_color="inverse",
)
# Displaying metrics in the columns
actual_spends_col.metric("Actual Spend", actual_spends)
actual_metrics_col.metric(
f"Actual {metrics_selected_formatted}",
actual_metric_value,
)
# Check if the percentage display for media starts with a negative sign
if str(metrics_percentage_display_all[1:]).startswith("-"):
# If negative, set the color to red
metrics_percentage_display_media_str = f'<span style="color:rgb(255, 43, 43)">red <strong>{metrics_percentage_display_media}</strong></span>'
else:
# If positive, set the color to green
metrics_percentage_display_media_str = f'<span style="color:rgb(9, 171, 59)">green <strong>{metrics_percentage_display_media}</strong></span>'
# Display percentage calculation note
st.markdown(
f"**Note:** The percentage change for the response metric in {metrics_percentage_display_media_str} reflects the change based on the media-driven portion only, excluding the fixed base contribution and the percentage in black **{metrics_percentage_display_all}** represents the change based on the total response metric, including the base contribution. For spends, the percentage change **{spends_percentage_display}** is based on the total actual spends (base spends are always zero).",
unsafe_allow_html=True,
)
# Divider
st.divider()
# Calculate ROI threshold
st.session_state.roi_threshold = (
original_scenario_data["actual_total_sales"]
- sum(original_scenario_data["constant"])
) / original_scenario_data["actual_total_spends"]
# Fetch and sort channels based on actual spends
channel_list = list(
sorted(
original_scenario_data["channels"],
key=lambda channel: (
original_scenario_data["channels"][channel]["actual_total_spends"]
* original_scenario_data["channels"][channel]["conversion_rate"]
),
reverse=True,
)
)
# Create columns for optimization goal and buttons
(
optimization_goal_col,
message_display_col,
button_col,
bounds_col,
) = st.columns([3, 6, 3, 3])
# Display spinnner or message
with message_display_col:
st.write("###")
spinner_placeholder = st.empty()
# Save Progress
with save_progress_col:
st.write("####") # Padding
save_progress_placeholder = st.empty()
# Save page progress
with spinner_placeholder, st.spinner("Saving Progress ..."):
if save_progress_placeholder.button("Save Progress", use_container_width=True):
# Update DB
update_db(
prj_id=st.session_state["project_number"],
page_nam="Scenario Planner",
file_nam="project_dct",
pkl_obj=pickle.dumps(st.session_state["project_dct"]),
schema=schema,
)
# Store the message details in session state
with message_display_col:
st.session_state.message_display = {
"type": "success",
"message": "Progress saved successfully!",
"icon": "💾",
}
st.toast("Progress saved successfully!", icon="💾")
# Create columns for absolute text, slider, percentage number and bound type
absolute_text_col, absolute_slider_col, percentage_number_col, all_bounds_col = (
st.columns([2, 4, 2, 2])
)
# Dropdown for selecting optimization goal
optimization_goal = optimization_goal_col.selectbox(
"Fix", ["Spend", metrics_selected_formatted]
)
# Button columns with padding for alignment
with button_col:
st.write("##") # Padding
optimize_button_col, reset_button_col = st.columns(2)
reset_button_col.button(
"Reset",
use_container_width=True,
on_click=reset_scenario,
args=(metrics_selected, panel_selected),
)
# Absolute value display
if optimization_goal == "Spend":
absolute_value = modified_scenario_data["actual_total_spends"]
st.session_state.total_absolute_main_key = numerize(
modified_scenario_data["modified_total_spends"]
/ st.session_state["multiplier"]
)
else:
absolute_value = modified_scenario_data["actual_total_sales"]
st.session_state.total_absolute_main_key = numerize(
modified_scenario_data["modified_total_sales"]
/ st.session_state["multiplier"]
)
total_absolute = absolute_text_col.text_input(
"Absolute",
key="total_absolute_main_key",
on_change=total_absolute_main_key_change,
args=(
metrics_selected,
panel_selected,
optimization_goal,
),
)
# Generate and process slider options
slider_options = list(
np.linspace(int(0.5 * absolute_value), int(1.5 * absolute_value), 50)
) # Generate range
slider_options.append(
modified_scenario_data["modified_total_spends"]
if optimization_goal == "Spend"
else modified_scenario_data["modified_total_sales"]
)
slider_options = sorted(slider_options) # Sort the list
numerized_slider_options = [
numerize(value / st.session_state["multiplier"]) for value in slider_options
] # Numerize each value
# Slider for adjusting absolute value within a range
st.session_state.total_absolute_key = numerize(
modified_scenario_data["modified_total_spends"] / st.session_state["multiplier"]
if optimization_goal == "Spend"
else modified_scenario_data["modified_total_sales"]
/ st.session_state["multiplier"]
)
slider_value = absolute_slider_col.select_slider(
"Absolute",
numerized_slider_options,
key="total_absolute_key",
on_change=total_absolute_key_change,
args=(
metrics_selected,
panel_selected,
optimization_goal,
),
)
# Number input for percentage value
if optimization_goal == "Spend":
st.session_state.total_percentage_key = int(
round(
(
(
modified_scenario_data["modified_total_spends"]
- modified_scenario_data["actual_total_spends"]
)
/ modified_scenario_data["actual_total_spends"]
)
* 100,
0,
)
)
else:
st.session_state.total_percentage_key = int(
round(
(
(
modified_scenario_data["modified_total_sales"]
- modified_scenario_data["actual_total_sales"]
)
/ modified_scenario_data["actual_total_sales"]
)
* 100,
0,
)
)
percentage_target = percentage_number_col.number_input(
"Percentage",
min_value=-50,
max_value=50,
key="total_percentage_key",
on_change=total_percentage_key_change,
args=(
metrics_selected,
panel_selected,
absolute_value,
optimization_goal,
),
)
# Toggle input for bound type
st.session_state["bound_type_key"] = modified_scenario_data["bound_type"]
with bounds_col:
st.write("##") # Padding
# Columns for custom bounds toggle and apply all bounds button
allow_custom_bounds_col, apply_all_bounds_col = st.columns(2)
# Toggle for enabling/disabling custom bounds
bound_type = allow_custom_bounds_col.toggle(
"Bounds",
on_change=bound_type_change,
key="bound_type_key",
)
# Button to apply all bounds
apply_all_bounds = apply_all_bounds_col.button(
"Apply All",
use_container_width=True,
on_click=all_bound_change,
args=(channel_list, True),
disabled=not bound_type,
)
# Section for setting all lower and upper bounds
with all_bounds_col:
lower_bound_all, upper_bound_all = st.columns([1, 1])
# Initialize session state keys for lower and upper bounds
st.session_state["all_lower_key"] = (modified_scenario_data["bounds"])[0]
st.session_state["all_upper_key"] = (modified_scenario_data["bounds"])[1]
# Input for all lower bounds
all_lower_bound = lower_bound_all.number_input(
"All Lower Bounds",
min_value=-100,
max_value=100,
key="all_lower_key",
on_change=all_bound_change,
args=(channel_list, False),
disabled=not bound_type,
)
# Input for all upper bounds
all_upper_bound = upper_bound_all.number_input(
"All Upper Bounds",
min_value=-100,
max_value=100,
key="all_upper_key",
on_change=all_bound_change,
args=(channel_list, False),
disabled=not bound_type,
)
# Collect inputs from the user interface
total_channel_spends, optimize_allow = 0, True
bounds_dict = {}
s_curve_params = {}
channels_spends = {}
channels_proportion = {}
channels_conversion_ratio = {}
channels_name_plot_placeholder = {}
# Optimization Inputs UI
with st.expander("Optimization Inputs", expanded=True):
# Initialize total contributions for actual and optimized spends and metrics
(
total_actual_spend_contribution,
total_actual_metric_contribution,
total_optimized_spend_contribution,
total_optimized_metric_contribution,
) = (
0,
sum(modified_scenario_data["constant"]),
0,
sum(modified_scenario_data["constant"]),
)
# Iterate over each channel in the channel list
for channel in channel_list:
# Accumulate actual total spends
total_actual_spend_contribution += (
modified_scenario_data["channels"][channel]["actual_total_spends"]
* modified_scenario_data["channels"][channel]["conversion_rate"]
)
# Accumulate actual total sales (metrics)
total_actual_metric_contribution += modified_scenario_data["channels"][
channel
]["actual_total_sales"]
# Accumulate optimized total spends
total_optimized_spend_contribution += (
modified_scenario_data["channels"][channel]["modified_total_spends"]
* modified_scenario_data["channels"][channel]["conversion_rate"]
)
# Accumulate optimized total sales (metrics)
total_optimized_metric_contribution += modified_scenario_data["channels"][
channel
]["modified_total_sales"]
for channel in channel_list:
st.divider()
# Channel key: combines metric, panel and channel so widget keys are unique
# per (metric, panel, channel) combination
channel_key = f"{metrics_selected}_{panel_selected}_{channel}"
# Create columns; the extra bounds-input column is only laid out when
# st.session_state["bound_type_key"] is truthy
if st.session_state["bound_type_key"]:
(
name_plot_col,
input_col,
spends_col,
metrics_col,
bounds_input_col,
bounds_display_col,
allow_col,
) = st.columns([3, 2, 2, 2, 2, 2, 1])
else:
(
name_plot_col,
input_col,
spends_col,
metrics_col,
bounds_display_col,
allow_col,
) = st.columns([1.5, 1, 1.5, 1.5, 1, 0.5])
# Keep the name bound; it is only used when bound_type_key is truthy
bounds_input_col = st.empty()
# Display channel name and ROI/MROI plot
with name_plot_col:
# Placeholder for channel name; drawn here without a color and redrawn
# with a spend-based background color after the response curves are built
channel_name_placeholder = st.empty()
channel_name_placeholder.markdown(
display_channel_name_with_background_color(channel),
unsafe_allow_html=True,
)
# Container for the ROI and MROI plots, filled after the response curves
channel_plot_placeholder = st.container()
# Store placeholder for channel name and ROI/MROI plots
channels_name_plot_placeholder[channel] = {
"channel_name_placeholder": channel_name_placeholder,
"channel_plot_placeholder": channel_plot_placeholder,
}
# Channel spends and sales. Spends are multiplied by conversion_rate; both
# actual and modified spends here use the ORIGINAL scenario's conversion_rate
# (the totals above use the modified scenario's).
# NOTE(review): presumably the two rates are identical - confirm.
channel_spends_actual = (
original_scenario_data["channels"][channel]["actual_total_spends"]
* original_scenario_data["channels"][channel]["conversion_rate"]
)
channel_metrics_actual = original_scenario_data["channels"][channel][
"actual_total_sales"
]
channel_spends_modified = (
modified_scenario_data["channels"][channel]["modified_total_spends"]
* original_scenario_data["channels"][channel]["conversion_rate"]
)
channel_metrics_modified = modified_scenario_data["channels"][channel][
"modified_total_sales"
]
# Channel spends input
with input_col:
# Absolute Spends Input. The widget value is overwritten from
# modified_scenario_data on every run, so the box always shows the stored
# scenario (converted spends divided by the display multiplier); user edits
# are handled by the absolute_channel_spends_change callback.
st.session_state[f"{channel_key}_abs_spends_key"] = numerize(
modified_scenario_data["channels"][channel]["modified_total_spends"]
* original_scenario_data["channels"][channel]["conversion_rate"]
/ st.session_state["multiplier"]
)
absolute_channel_spends = st.text_input(
"Absolute Spends",
key=f"{channel_key}_abs_spends_key",
on_change=absolute_channel_spends_change,
args=(
channel_key,
channel_spends_actual,
channel,
metrics_selected,
panel_selected,
),
)
# Update Percentage Spends Input: change of the absolute value versus the
# actual spend, in percent, rounded to a whole number.
# NOTE(review): fails when channel_spends_actual is 0 (division by zero);
# the error would end up in the page-level except handler.
st.session_state[f"{channel_key}_per_spends_key"] = int(
round(
(
(
convert_to_float(
st.session_state[f"{channel_key}_abs_spends_key"]
)
* st.session_state["multiplier"]
- float(channel_spends_actual)
)
/ channel_spends_actual
)
* 100,
0,
)
)
# Percentage Spends Input (limited to -1000..1000 %); edits are handled by
# the percentage_channel_spends_change callback
percentage_channel_spends = st.number_input(
"Percentage Spends",
min_value=-1000,
max_value=1000,
key=f"{channel_key}_per_spends_key",
on_change=percentage_channel_spends_change,
args=(
channel_key,
channel_spends_actual,
channel,
metrics_selected,
panel_selected,
),
)
# Store channel spends, conversion ratio and proportion list.
# channels_spends is in stored (unconverted) units and is rebuilt from the
# integer-rounded percentage, so it can differ slightly from the absolute
# value shown in the text box.
channels_spends[channel] = original_scenario_data["channels"][channel][
"actual_total_spends"
] * (1 + percentage_channel_spends / 100)
channels_conversion_ratio[channel] = original_scenario_data["channels"][
channel
]["conversion_rate"]
# Share of the channel's total spend held by each entry of "spends"; the
# value must support element-wise division by a scalar (e.g. a numpy
# array - a plain list would raise TypeError here).
channels_proportion[channel] = original_scenario_data["channels"][
channel
]["spends"] / sum(original_scenario_data["channels"][channel]["spends"])
# Calculate the percent contribution of actual spends for the channel
# (share of the converted actual spend total, rounded to one decimal)
channel_actual_spend_contribution = round(
(
modified_scenario_data["channels"][channel][
"actual_total_spends"
]
* channels_conversion_ratio[channel]
/ total_actual_spend_contribution
)
* 100,
1,
)
# Calculate the percent contribution of actual metrics (sales) for the channel
# (the denominator includes the constant term)
channel_actual_metric_contribution = round(
(
modified_scenario_data["channels"][channel][
"actual_total_sales"
]
/ total_actual_metric_contribution
)
* 100,
1,
)
# Calculate the percent contribution of optimized spends for the channel
channel_optimized_spend_contribution = round(
(
modified_scenario_data["channels"][channel][
"modified_total_spends"
]
* channels_conversion_ratio[channel]
/ total_optimized_spend_contribution
)
* 100,
1,
)
# Calculate the percent contribution of optimized metrics (sales) for the channel
channel_optimized_metric_contribution = round(
(
modified_scenario_data["channels"][channel][
"modified_total_sales"
]
/ total_optimized_metric_contribution
)
* 100,
1,
)
# Channel metrics display (values scaled by the display multiplier, with the
# percent contribution appended)
with metrics_col:
# Absolute Metrics
st.metric(
f"Actual {name_formating(metrics_selected)}",
value=str(
numerize(
channel_metrics_actual / st.session_state["multiplier"]
)
)
+ f"({channel_actual_metric_contribution}%)",
)
# Optimized Metrics
optimized_metric = (
channel_metrics_modified / st.session_state["multiplier"]
)
actual_metric = channel_metrics_actual / st.session_state["multiplier"]
delta_value = (
channel_metrics_modified - channel_metrics_actual
) / st.session_state["multiplier"]
# Show a delta of "0" when its magnitude is below 0.1% of the actual value
if (
abs(delta_value) < 0.001 * actual_metric
): # 0.1% of the actual metric
delta_display = "0"
else:
delta_display = numerize(delta_value)
st.metric(
f"Optimized {name_formating(metrics_selected)}",
value=str(numerize(optimized_metric))
+ f"({channel_optimized_metric_contribution}%)",
delta=delta_display,
)
# Channel spends display
with spends_col:
# Absolute Spends
st.metric(
"Actual Spend",
value=str(
numerize(channel_spends_actual / st.session_state["multiplier"])
)
+ f"({channel_actual_spend_contribution}%)",
)
# Optimized Spends. NOTE: `optimized_spends` is a display value here; the
# same name is later rebound to the optimizer's per-channel dict.
optimized_spends = (
channel_spends_modified / st.session_state["multiplier"]
)
actual_spends = channel_spends_actual / st.session_state["multiplier"]
delta_spends_value = (
channel_spends_modified - channel_spends_actual
) / st.session_state["multiplier"]
# Show a delta of "0" when its magnitude is below 0.1% of the actual spend
if (
abs(delta_spends_value) < 0.001 * actual_spends
): # 0.1% of the actual spend
delta_spends_display = "0"
else:
delta_spends_display = numerize(delta_spends_value)
st.metric(
"Optimized Spend",
value=str(numerize(optimized_spends))
+ f"({channel_optimized_spend_contribution}%)",
delta=delta_spends_display,
)
# Channel allows optimize
with allow_col:
# Allow Optimize (Freeze). The checkbox value is overwritten from the
# stored "freeze" flag on every run; changes go through freeze_change.
st.write("#") # Padding
st.session_state[f"{channel_key}_allow_optimize_key"] = (
modified_scenario_data["channels"][channel]["freeze"]
)
freeze = st.checkbox(
"Freeze",
key=f"{channel_key}_allow_optimize_key",
on_change=freeze_change,
args=(
metrics_selected,
panel_selected,
channel_key,
channel,
channel_list,
),
)
# If channel is frozen, set bounds to keep the spend unchanged
# NOTE(review): when bound_type_key is truthy these values are overwritten
# by the (disabled) bound widgets below, so a frozen channel then keeps its
# stored bounds; whether those are [0, 0] depends on freeze_change.
if freeze:
lower_bound, upper_bound = 0, 0 # Freeze the spend at current level
# Channel bounds input (percent change of the spend, limited to -100..100)
if st.session_state["bound_type_key"]:
with bounds_input_col:
# Channel upper bound, pre-filled from the stored bounds[1]
st.session_state[f"{channel_key}_upper_key"] = (
modified_scenario_data["channels"][channel]["bounds"]
)[1]
upper_bound = st.number_input(
"Upper bound (%)",
min_value=-100,
max_value=100,
key=f"{channel_key}_upper_key",
disabled=st.session_state[f"{channel_key}_allow_optimize_key"],
on_change=bound_change,
args=(
metrics_selected,
panel_selected,
channel_key,
channel,
),
)
# Channel lower bound, pre-filled from the stored bounds[0]
st.session_state[f"{channel_key}_lower_key"] = (
modified_scenario_data["channels"][channel]["bounds"]
)[0]
lower_bound = st.number_input(
"Lower bound (%)",
min_value=-100,
max_value=100,
key=f"{channel_key}_lower_key",
disabled=st.session_state[f"{channel_key}_allow_optimize_key"],
on_change=bound_change,
args=(
metrics_selected,
panel_selected,
channel_key,
channel,
),
)
# If lower bound is greater than upper bound, fall back to -10/+10.
# NOTE(review): only the local values change; the widgets and the stored
# bounds keep the inconsistent values.
if lower_bound > upper_bound:
lower_bound = -10 # Default lower bound
upper_bound = 10 # Default upper bound
# Store bounds used by the optimizer
bounds_dict[channel] = [lower_bound, upper_bound]
else:
# No bound inputs: the locals below only drive the bounds display, while
# the optimizer receives the stored bounds via bounds_dict.
# NOTE(review): the display can disagree with bounds_dict[channel] when the
# stored bounds are not [-10, 10] (or [0, 0] for frozen channels).
# If channel is frozen, set bounds to keep the spend unchanged
if freeze:
lower_bound, upper_bound = 0, 0 # Freeze the spend at current level
else:
lower_bound = -10 # Default lower bound
upper_bound = 10 # Default upper bound
# Store bounds
bounds_dict[channel] = modified_scenario_data["channels"][channel][
"bounds"
]
# Display the bounds for each channel's spend in the bounds_display_col
with bounds_display_col:
# Base the limits on the channel's current (modified) spend, converted with
# the modified scenario's conversion rate
actual_spends = (
modified_scenario_data["channels"][channel]["modified_total_spends"]
* modified_scenario_data["channels"][channel]["conversion_rate"]
)
# Calculate the spend limits implied by the percentage bounds
upper_limit_spends = actual_spends * (1 + upper_bound / 100)
lower_limit_spends = actual_spends * (1 + lower_bound / 100)
# Display the upper and lower limit spends (scaled by the display multiplier)
st.metric(
"Upper Bound",
numerize(upper_limit_spends / st.session_state["multiplier"]),
)
st.metric(
"Lower Bound",
numerize(lower_limit_spends / st.session_state["multiplier"]),
)
# Store S-curve parameters for this channel (helper defined elsewhere)
s_curve_params[channel] = get_s_curve_params(
metrics_selected,
panel_selected,
channel,
original_rcs_data,
modified_rcs_data,
)
# Total channel spends: sum of the absolute spends inputs, rescaled by the
# display multiplier (total_channel_spends is initialized before this chunk)
total_channel_spends += (
convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"])
* st.session_state["multiplier"]
)
# Check if total channel spends are within the allowed range (±50% of the original total spends).
# This only records a warning message; it does not change optimize_allow.
# NOTE(review): meant to run after all channels are accumulated - confirm the
# nesting. Also confirm original_scenario_data["actual_total_spends"] is in the
# same converted units as total_channel_spends.
if (
total_channel_spends > 1.5 * original_scenario_data["actual_total_spends"]
or total_channel_spends
< 0.5 * original_scenario_data["actual_total_spends"]
):
# Store the message details in session state
st.session_state.message_display = {
"type": "warning",
"message": "Keep total spending within ±50% of the original value.",
"icon": "⚠️",
}
if optimization_goal == "Spend":
# Get the achievable total spend range: each channel's converted spend moved
# to its lower / upper bound, summed over channels
lower_achievable_target, upper_achievable_target = 0, 0
for channel in channel_list:
channel_spends_actual = (
channels_spends[channel] * channels_conversion_ratio[channel]
)
lower_achievable_target += channel_spends_actual * (
1 + bounds_dict[channel][0] / 100
)
upper_achievable_target += channel_spends_actual * (
1 + bounds_dict[channel][1] / 100
)
else:
# Get the achievable range of the target metric within the bounds
# (max_target_achievable is defined elsewhere)
lower_achievable_target, upper_achievable_target = max_target_achievable(
channels_spends,
s_curve_params,
channels_proportion,
modified_scenario_data,
bounds_dict,
)
# Total target of selected metric: the current modified scenario total
if optimization_goal == "Spend":
total_absolute_target = modified_scenario_data["modified_total_spends"]
else:
total_absolute_target = modified_scenario_data["modified_total_sales"]
# Check if the target is achievable within the specified bounds; this may
# switch optimize_allow off, which disables the Optimize button below
if optimize_allow:
optimize_allow = check_target_achievability(
optimize_allow,
name_formating(optimization_goal),
lower_achievable_target,
upper_achievable_target,
total_absolute_target,
)
# Perform the optimization when the Optimize button is clicked (the button is
# disabled while optimize_allow is falsy)
if optimize_button_col.button(
"Optimize",
use_container_width=True,
disabled=not optimize_allow,
key="run_optimizer",
):
with message_display_col:
with spinner_placeholder, st.spinner("Optimizing ..."):
# Call the optimizer function to get optimized spends per channel and a
# success flag; the spends are used below like channels_spends (i.e.
# multiplied by the conversion ratio to get converted spends)
optimized_spends, optimization_success = optimizer(
optimization_goal,
s_curve_params,
channels_spends,
channels_proportion,
channels_conversion_ratio,
total_absolute_target,
bounds_dict,
modified_scenario_data,
)
# Initialize dictionaries to store input and output channel spends
# NOTE(review): assumes optimized_spends has a key for every channel in
# channel_list (a missing key raises KeyError here)
input_channels_spends, output_channels_spends = {}, {}
for channel in channel_list:
# Calculate input channel spends by converting spends using conversion ratio
input_channels_spends[channel] = (
channels_spends[channel] * channels_conversion_ratio[channel]
)
# Calculate output channel spends by converting optimized spends using conversion ratio
output_channels_spends[channel] = (
optimized_spends[channel] * channels_conversion_ratio[channel]
)
# Calculate total actual and modified (converted) spends; sum() consumes the
# dict values view directly, so no intermediate list is needed
actual_total_spends = sum(input_channels_spends.values())
modified_total_spends = sum(output_channels_spends.values())
# Retrieve the current total metric (the scenario state before this run) from
# modified scenario data; it is the baseline passed to the success check
actual_total_metrics = modified_scenario_data["modified_total_sales"]
modified_total_metrics = 0 # Initialize modified total metrics
modified_channels_metrics = {}
# Calculate modified metrics for each channel: the optimized channel total is
# split using the channel's spend proportions, the S-curve is evaluated on each
# entry and summed, and the channel's summed correction is added
for channel in optimized_spends.keys():
channel_s_curve_params = s_curve_params[channel]
spend_proportion = (
optimized_spends[channel] * channels_proportion[channel]
)
# Calculate the metrics using the S-curve function
modified_channels_metrics[channel] = sum(
s_curve(
spend_proportion,
channel_s_curve_params["power"],
channel_s_curve_params["K"],
channel_s_curve_params["b"],
channel_s_curve_params["a"],
channel_s_curve_params["x0"],
)
) + sum(
modified_scenario_data["channels"][channel]["correction"]
) # correction for s-curve
modified_total_metrics += modified_channels_metrics[
channel
] # Add channel metrics to total metrics
# Add the constant term to the modified total metrics (per-channel corrections
# were already included above)
modified_total_metrics += sum(modified_scenario_data["constant"])
# Retrieve the original total spends from modified scenario data
original_total_spends = modified_scenario_data["actual_total_spends"]
# Check the success of the optimization process; returns the success flag plus
# the message text and icon to display
success, message, icon = check_optimization_success(
channel_list,
input_channels_spends,
output_channels_spends,
bounds_dict,
optimization_goal,
modified_total_metrics,
actual_total_metrics,
modified_total_spends,
actual_total_spends,
original_total_spends,
optimization_success,
)
# Store the message details in session state
st.session_state.message_display = {
"type": "success" if success else "error",
"message": message,
"icon": icon,
}
# Update data only if the optimization is successful. Per-channel spends are
# stored unconverted (optimized_spends), while the scenario-level
# modified_total_spends is the converted total computed above.
if success:
# Update the modified spend and metrics for each channel in the scenario data
for channel in channel_list:
modified_scenario_data["channels"][channel][
"modified_total_spends"
] = optimized_spends[channel]
# Update the modified metrics for each channel in the scenario data
modified_scenario_data["channels"][channel][
"modified_total_sales"
] = modified_channels_metrics[channel]
# Update the total modified spends in the scenario data
modified_scenario_data["modified_total_spends"] = (
modified_total_spends
)
# Update the total modified metrics in the scenario data
modified_scenario_data["modified_total_sales"] = (
modified_total_metrics
)
# Load modified scenario data (a reference to the dict held in session
# state, so the assignment below mutates it in place)
data = st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
]
# Update the specific section with the modified scenario data
data[metrics_selected][panel_selected] = modified_scenario_data
# Update modified scenario metadata
st.session_state["project_dct"]["scenario_planner"][
"modified_metadata_file"
] = data
# Reset optimizer button state
del st.session_state["run_optimizer"]
# Rerun so all widgets are rebuilt from the updated scenario data
st.rerun()
########################################## Response Curves ##########################################
# Generate plots (one figure per channel) together with per-channel ROI/MROI
# data and the spend regions later used to color the channel names
figures, channel_roi_mroi, region_start_end = generate_response_curve_plots(
channel_list,
s_curve_params,
channels_proportion,
original_scenario_data,
st.session_state["multiplier"],
)
# Display Response Curves
st.subheader(f"Response Curves (X: Spends Vs Y: {metrics_selected_formatted})")
with st.expander("Response Curves", expanded=True):
cols = st.columns(4) # Create 4 columns for the first row
for i, fig in enumerate(figures):
col = cols[i % 4] # Rotate through the columns
with col:
# Get channel parameters; assumes figures are returned in channel_list
# order - TODO confirm against generate_response_curve_plots
channel = channel_list[i]
modified_total_spends = modified_scenario_data["channels"][channel][
"modified_total_spends"
]
conversion_rate = modified_scenario_data["channels"][channel][
"conversion_rate"
]
channel_correction = sum(
modified_scenario_data["channels"][channel]["correction"]
)
# Updated figure with modified metrics point (helper defined elsewhere);
# returns the ROI, the MROI and the updated figure
roi_optimized, mroi_optimized, fig_updated = modified_metrics_point(
fig,
modified_total_spends,
s_curve_params[channel],
channels_proportion[channel],
conversion_rate,
channel_correction,
)
# Store data of each channel ROI and MROI (later passed to save_scenario)
channel_roi_mroi[channel]["optimized_roi"] = roi_optimized
channel_roi_mroi[channel]["optimized_mroi"] = mroi_optimized
st.plotly_chart(fig_updated, use_container_width=True)
# Start a new row after every 4 plots
# NOTE(review): ensure this runs outside `with col:` so the new row is not
# created inside the current column
if (i + 1) % 4 == 0 and i + 1 < len(figures):
cols = st.columns(4) # Create new row with 4 columns
# Generate the per-channel ROI/MROI plots (roi_mori_plot is defined elsewhere)
channel_roi_mroi_plot = roi_mori_plot(channel_roi_mroi)
# Draw the plots and the colored channel names into the placeholders created in
# the Optimization Inputs section
for channel in channel_list:
with channels_name_plot_placeholder[channel]["channel_plot_placeholder"]:
# Create subplots with 2 columns for ROI and MROI
roi_plot_col, mroi_plot_col = st.columns(2)
# Display ROI and MROI plots
roi_plot_col.plotly_chart(channel_roi_mroi_plot[channel]["fig_roi"])
mroi_plot_col.plotly_chart(channel_roi_mroi_plot[channel]["fig_mroi"])
# Placeholder for the channel name
channel_name_placeholder = channels_name_plot_placeholder[channel][
"channel_name_placeholder"
]
# Retrieve modified total spends and conversion rate for the channel
modified_total_spends = modified_scenario_data["channels"][channel][
"modified_total_spends"
]
conversion_rate = modified_scenario_data["channels"][channel]["conversion_rate"]
# Calculate the converted spend value for the channel
channel_spends_value = modified_total_spends * conversion_rate
# Calculate the RGBA color value for the channel from its spend and the
# channel's response-curve regions (region_start_end)
channel_rgba_value = calculate_rgba(
channel_spends_value, region_start_end[channel]
)
# Redraw the channel name with the calculated background color
channel_name_placeholder.markdown(
display_channel_name_with_background_color(channel, channel_rgba_value),
unsafe_allow_html=True,
)
# Input field for the scenario name
st.text_input("Scenario Name", key="scenario_name")
# Disable the "Save Scenario" button until a name is provided; None and the
# empty string are both falsy, so truthiness covers both missing cases
save_scenario_button_disabled = not st.session_state["scenario_name"]
# Button to save the scenario
save_button_placeholder = st.empty()
# NOTE(review): the spinner only wraps rendering the button. Streamlit runs
# on_click callbacks at the start of the next rerun, before this line is
# reached, so "Saving ..." is likely never visible while save_scenario runs.
with st.spinner("Saving ..."):
save_button_placeholder.button(
"Save Scenario",
on_click=save_scenario,
args=(
modified_scenario_data,
metrics_selected,
panel_selected,
optimization_goal,
channel_roi_mroi,
st.session_state["timeframe_selected_selectbox_sp"],
st.session_state["multiplier"],
),
disabled=save_scenario_button_disabled,
)
########################################## Display Message ##########################################
# Display the pending message (display_message is defined elsewhere; presumably
# it renders st.session_state.message_display, which this chunk sets above)
with message_display_col:
display_message()
except Exception as e:
# Page-level guard: any error raised while building the page is logged with
# its full traceback and a generic warning is shown to the user.
# (`e` itself is unused; traceback.format_exc() would yield the same text.)
# Capture the error details
exc_type, exc_value, exc_traceback = sys.exc_info()
error_message = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
# Log message
log_message("error", f"An error occurred: {error_message}.", "Scenario Planner")
# Display a warning message
st.warning(
"Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
icon="⚠️",
)