# MediaMixOptimization / pages/8_Response_Curves.py
# samkeet's picture
# Upload 40 files
# 00b00eb verified
# raw
# history blame
# 18 kB
import streamlit as st
# Configure the browser tab title/icon and the page layout.
st.set_page_config(
    page_title="Response Curves",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# CSS overrides injected into the page: hide the +/- step buttons of number
# inputs and give baseweb containers a 4px corner radius.
_PAGE_CSS = """
<style>
button.step-up {display: none;}
button.step-down {display: none;}
div[data-baseweb] {border-radius: 4px;}
</style>"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
import sys
import json
import pickle
import traceback
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from post_gres_cred import db_cred
from sklearn.metrics import r2_score
from log_application import log_message
from utilities import project_selection, update_db, set_header, load_local_css
from utilities import (
get_panels_names,
get_metrics_names,
name_formating,
generate_rcs_data,
load_rcs_metadata_files,
)
# Schema name read from the imported credentials mapping
schema = db_cred["schema"]

# Load the local stylesheet and set the page header
load_local_css("styles.css")
set_header()

# Guarantee that the project name key exists in session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# No project dictionary yet: run project selection and end this script run
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Show the username and the current project name side by side
if st.session_state.get("username") is not None:
    cols1 = st.columns([2, 1])
    cols1[0].markdown(f"**Welcome {st.session_state['username']}**")
    cols1[1].markdown(f"**Current Project: {st.session_state['project_name']}**")
# Function to build s curve
def s_curve(x, K, b, a, x0):
    """Evaluate the S-curve ``K / (1 + b * exp(-a * (x - x0)))`` at ``x``.

    ``x`` is passed straight through NumPy arithmetic, so a scalar or a
    NumPy array both work; the other arguments are the curve parameters.
    """
    exponential_term = np.exp(-a * (x - x0))
    denominator = 1 + b * exponential_term
    return K / denominator
# Function to update the RCS parameters in the modified RCS metadata data
def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
    """Copy the K, b, a and x0 values held in session state into the
    modified RCS metadata for the given metric, panel and channel.

    The session-state keys follow the pattern
    ``<param>_updated_key_<metric>_<panel>_<channel>``.
    """
    # Suffix shared by the four session-state keys of this selection
    selection_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"

    # Read the current value of every parameter from session state
    updated_parameters = {
        param: st.session_state[f"{param}_updated_key_{selection_suffix}"]
        for param in ("K", "b", "a", "x0")
    }

    # Replace the channel entry in the modified RCS data with the new values
    rcs_data_modified = st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ]
    rcs_data_modified[metrics_selected][panel_selected][
        channel_selected
    ] = updated_parameters
# Function to reset the parameters to their default values
def reset_parameters(
    metrics_selected, panel_selected, channel_selected, original_channel_data
):
    """Restore the original K, b, a and x0 for one metric/panel/channel.

    Removes the four matching ``<param>_updated_key_...`` entries from
    session state and overwrites the channel entry in the modified RCS
    metadata with the values taken from ``original_channel_data``.
    """
    parameter_names = ("K", "b", "a", "x0")
    selection_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"

    # Drop the session-state values tied to this selection
    for param in parameter_names:
        del st.session_state[f"{param}_updated_key_{selection_suffix}"]

    # Overwrite the channel entry of the modified data with the original values
    response_curves = st.session_state["project_dct"]["response_curves"]
    rcs_data_modified = response_curves["modified_metadata_file"]
    rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
        param: original_channel_data[param] for param in parameter_names
    }

    # Store the modified metadata back on the project dictionary
    response_curves["modified_metadata_file"] = rcs_data_modified
# Function to generate updated RCS parameter DataFrame
@st.cache_data(show_spinner=False)
def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
    """Build a DataFrame of the channels whose parameters were changed.

    For the selected metric and panel, every channel of ``original_data`` is
    compared with the same channel of ``modified_data``. A row holding the
    original and modified K, b, a and x0 is emitted only when at least one
    of the four values differs. Returns an empty DataFrame if nothing differs.
    """
    parameter_names = ("K", "b", "a", "x0")
    original_channels = original_data[metrics_selected][panel_selected]
    modified_channels = modified_data[metrics_selected][panel_selected]

    rows = []
    for channel in original_channels:
        original_values = [original_channels[channel][p] for p in parameter_names]
        modified_values = [modified_channels[channel][p] for p in parameter_names]

        # Skip channels whose four parameters are all unchanged
        if not any(o != m for o, m in zip(original_values, modified_values)):
            continue

        row = {
            "Metric": name_formating(metrics_selected),
            "Panel": name_formating(panel_selected),
            "Channel": name_formating(channel),
        }
        # Original columns first, then the modified ones
        for param, value in zip(parameter_names, original_values):
            row[f"{param} (Original)"] = value
        for param, value in zip(parameter_names, modified_values):
            row[f"{param} (Modified)"] = value
        rows.append(row)

    return pd.DataFrame(rows)
# Function to create JSON file for RCS data
def create_json_file():
    """Return the modified RCS metadata from session state as a JSON string.

    The data is read from
    ``st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]``
    and serialized with an indent of 4.

    This function is intentionally not wrapped in ``st.cache_data``: it takes
    no arguments, so a cached version would keep returning the first JSON it
    ever produced, ignoring later parameter edits and the session it is
    called from.
    """
    return json.dumps(
        st.session_state["project_dct"]["response_curves"]["modified_metadata_file"],
        indent=4,
    )
# Page body. Everything runs inside one try block so that an unexpected
# failure is logged with its traceback and reported to the user as a warning.
try:
    # Page Title
    st.title("Response Curves")

    # Retrieve the list of all metric names
    metrics_list = get_metrics_names()

    # Without any metrics there is nothing to show: warn, log and stop
    if not metrics_list:
        st.warning(
            "Please tune at least one model to generate response curves data.",
            icon="⚠️",
        )
        log_message(
            "warning",
            "Please tune at least one model to generate response curves data.",
            "Response Curves",
        )
        # Stop further execution as there is no data to process
        st.stop()

    # Widget columns
    metric_col, channel_col, panel_col, save_progress_col = st.columns(4)

    # Metrics Selection
    metrics_selected = metric_col.selectbox(
        "Response Metrics",
        sorted(metrics_list),
        format_func=name_formating,
        key="response_metrics_selectbox",
        index=0,
    )

    # Retrieve the list of all panel names for the selected metric
    panel_list = get_panels_names(metrics_selected)

    # Panel Selection
    panel_selected = panel_col.selectbox(
        "Panel",
        sorted(panel_list),
        format_func=name_formating,
        key="panel_selected_selectbox",
        index=0,
    )

    # Save Progress
    with save_progress_col:
        st.write("####")  # Padding
        save_progress_placeholder = st.empty()

    # Placeholder to display message and spinner
    message_spinner_placeholder = st.container()

    # Save page progress: the button lives in the placeholder created above,
    # while the spinner and the success message render in the container
    with message_spinner_placeholder, st.spinner("Saving Progress ..."):
        if save_progress_placeholder.button("Save Progress", use_container_width=True):
            # Update DB with the pickled project dictionary
            update_db(
                prj_id=st.session_state["project_number"],
                page_nam="Response Curves",
                file_nam="project_dct",
                pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                schema=schema,
            )

            # Tell the user that the save succeeded
            message_spinner_placeholder.success(
                "Progress saved successfully!", icon="💾"
            )
            st.toast("Progress saved successfully!", icon="💾")

            # Log message
            log_message("info", "Progress saved successfully!", "Response Curves")

    # Generate the RCS data when either metadata entry is still missing
    response_curves_state = st.session_state["project_dct"]["response_curves"]
    if (
        response_curves_state["original_metadata_file"] is None
        or response_curves_state["modified_metadata_file"] is None
    ):
        # RCS metadata file does not exist. Generating new RCS data
        generate_rcs_data()

        # Log message
        log_message(
            "info",
            "RCS metadata file does not exist. Generating new RCS data.",
            "Response Curves",
        )

    # Load the original and modified metadata
    original_data, modified_data = load_rcs_metadata_files()

    # Retrieve the list of all channel names for the selected metric and panel
    chanel_list_final = list(original_data[metrics_selected][panel_selected].keys())

    # Channel Selection
    channel_selected = channel_col.selectbox(
        "Channel",
        sorted(chanel_list_final),
        format_func=name_formating,
        key="selected_channel_name_selectbox",
    )

    # Original and modified data for the selected metric, panel and channel
    original_channel_data = original_data[metrics_selected][panel_selected][
        channel_selected
    ]
    modified_channel_data = modified_data[metrics_selected][panel_selected][
        channel_selected
    ]

    # X and Y values for plotting
    x = original_channel_data["x"]
    y = original_channel_data["y"]

    # Power of ten the X values are divided by, and the X range of the curves
    power = original_channel_data["power"]
    x_plot = original_channel_data["x_plot"]

    # Original S-curve parameters
    K_orig = original_channel_data["K"]
    b_orig = original_channel_data["b"]
    a_orig = original_channel_data["a"]
    x0_orig = original_channel_data["x0"]

    # Modified S-curve parameters (user-adjusted)
    K_mod = modified_channel_data["K"]
    b_mod = modified_channel_data["b"]
    a_mod = modified_channel_data["a"]
    x0_mod = modified_channel_data["x0"]

    # Create a scatter plot for the original data points
    fig = px.scatter(
        x=x,
        y=y,
        title="Original and Modified S-Curve Plot",
        labels={"x": "Spends", "y": name_formating(metrics_selected)},
    )

    # Add the modified S-curve trace
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(
                np.array(x_plot) / 10**power,
                K_mod,
                b_mod,
                a_mod,
                x0_mod,
            ),
            line=dict(color="red"),
            name="Modified",
        ),
    )

    # Add the original S-curve trace
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(
                np.array(x_plot) / 10**power,
                K_orig,
                b_orig,
                a_orig,
                x0_orig,
            ),
            line=dict(color="rgba(0, 255, 0, 0.6)"),  # Semi-transparent green
            name="Original",
        ),
    )

    # Customize the layout of the plot (this title replaces the scatter title)
    fig.update_layout(
        title="Comparison of Original and Modified Response-Curves",
        xaxis_title="Input (Clicks, Impressions, etc..)",
        yaxis_title=name_formating(metrics_selected),
        legend_title="Curve Type",
    )

    # Display s-curve
    st.plotly_chart(fig, use_container_width=True)

    # Calculate R-squared for the original curve
    y_orig_pred = s_curve(np.array(x) / 10**power, K_orig, b_orig, a_orig, x0_orig)
    r2_orig = r2_score(y, y_orig_pred)

    # Calculate R-squared for the modified curve
    y_mod_pred = s_curve(np.array(x) / 10**power, K_mod, b_mod, a_mod, x0_mod)
    r2_mod = r2_score(y, y_mod_pred)

    # Calculate the difference in R-squared
    r2_diff = r2_mod - r2_orig

    # Display R-squared metrics
    st.write("## R-squared Comparison")
    r2_col = st.columns(3)
    r2_col[0].metric("R-squared (Original)", f"{r2_orig:.2f}")
    r2_col[1].metric("R-squared (Modified)", f"{r2_mod:.2f}")
    r2_col[2].metric("Difference in R-squared", f"{r2_diff:.2f}")

    # Define unique keys for each parameter based on the selection
    K_key = f"K_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
    b_key = f"b_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
    a_key = f"a_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"
    x0_key = f"x0_updated_key_{metrics_selected}_{panel_selected}_{channel_selected}"

    # (label, session-state key, current modified value) for each parameter
    rcs_inputs = (
        ("K", K_key, K_mod),
        ("b", b_key, b_mod),
        ("a", a_key, a_mod),
        ("x0", x0_key, x0_mod),
    )

    # Initialize session state keys if they do not exist
    for _, state_key, modified_value in rcs_inputs:
        if state_key not in st.session_state:
            st.session_state[state_key] = modified_value

    # RCS parameters input: one number input per parameter, each in its own
    # column. The value is read back through session state by the
    # on_change callback, so the widget return value is not needed here.
    rsc_ip_col = st.columns(4)
    for input_col, (label, state_key, _) in zip(rsc_ip_col, rcs_inputs):
        with input_col:
            st.number_input(
                label,
                step=0.001,
                min_value=0.000000,
                format="%.6f",
                on_change=modify_rcs_parameters,
                args=(metrics_selected, panel_selected, channel_selected),
                key=state_key,
            )

    # Create columns for Reset and Download buttons
    reset_download_col = st.columns(2)
    with reset_download_col[0]:
        if st.button(
            "Reset",
            use_container_width=True,
        ):
            reset_parameters(
                metrics_selected,
                panel_selected,
                channel_selected,
                original_channel_data,
            )

            # Log message
            log_message(
                "info",
                f"METRIC: {name_formating(metrics_selected)} ; PANEL: {name_formating(panel_selected)}, CHANNEL: {name_formating(channel_selected)} has been reset to its original value.",
                "Response Curves",
            )

            # Rerun so the page picks up the reset values
            st.rerun()

    with reset_download_col[1]:
        # Provide a download button for the modified RCS data
        try:
            # Create JSON file for RCS data
            json_data = create_json_file()
            st.download_button(
                label="Download",
                data=json_data,
                file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
                mime="application/json",
                use_container_width=True,
            )
        except Exception:
            # A failed download must not take the rest of the page down, but
            # it should not disappear silently either: record it in the log.
            log_message(
                "error",
                f"Download of RCS data failed: {traceback.format_exc()}",
                "Response Curves",
            )

    # Generate the DataFrame showing only non-matching parameters
    updated_parm_df = updated_parm_gen(
        original_data, modified_data, metrics_selected, panel_selected
    )

    # Display the DataFrame or show an informational message if no updates
    if not updated_parm_df.empty:
        st.write("## Parameter Comparison for Selected Metric and Panel")
        st.dataframe(updated_parm_df, hide_index=True)
    else:
        st.info("No parameters are updated for the selected Metric and Panel")

except Exception:
    # Capture the full traceback of the error
    error_message = traceback.format_exc()

    # Log message
    log_message("error", f"An error occurred: {error_message}.", "Response Curves")

    # Display a warning message
    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )