# Spaces: Running
import streamlit as st | |
# Page-level settings for the "Response Curves" page.
# NOTE(review): this runs ahead of the remaining imports, presumably so that it
# is the first Streamlit command executed - confirm before moving it.
_PAGE_SETTINGS = dict(
    page_title="Response Curves",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_SETTINGS)

# CSS that hides the +/- stepper buttons of number inputs and applies a 4px
# border radius to every div carrying a data-baseweb attribute.
_NUMBER_INPUT_CSS = """
<style>
button.step-up {display: none;}
button.step-down {display: none;}
div[data-baseweb] {border-radius: 4px;}
</style>"""
st.markdown(_NUMBER_INPUT_CSS, unsafe_allow_html=True)
import sys | |
import json | |
import pickle | |
import traceback | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from post_gres_cred import db_cred | |
from sklearn.metrics import r2_score | |
from log_application import log_message | |
from utilities import project_selection, update_db, set_header, load_local_css | |
from utilities import ( | |
get_panels_names, | |
get_metrics_names, | |
name_formating, | |
generate_rcs_data, | |
load_rcs_metadata_files, | |
) | |
# Database schema name, read from the Postgres credentials mapping
schema = db_cred["schema"]

# Load the local stylesheet and set the page header
load_local_css("styles.css")
set_header()

# Make sure the project-name slot exists in session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# No project dictionary yet: run project selection and end this script run
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Show the user name and the current project side by side (2:1 width ratio)
active_user = st.session_state.get("username")
if active_user is not None:
    welcome_col, project_col = st.columns([2, 1])
    welcome_col.markdown(f"**Welcome {active_user}**")
    project_col.markdown(
        f"**Current Project: {st.session_state['project_name']}**"
    )
# Function to build s curve
def s_curve(x, K, b, a, x0):
    """Evaluate the S-curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Args:
        x: Scalar or array-like of input values (lists and tuples are accepted
            and converted to a float array).
        K: Numerator of the curve (its value wherever ``b * exp(...)`` is 0).
        b: Multiplier applied to the exponential term.
        a: Multiplier applied to ``(x - x0)`` inside the exponential.
        x0: Horizontal shift subtracted from ``x``.

    Returns:
        The curve value(s); a NumPy scalar for scalar input, otherwise an
        array with the same shape as ``x``.
    """
    exponent = -a * (np.asarray(x, dtype=float) - x0)

    # exp() overflows to inf for arguments above ~709.78. With b == 0 that
    # turned the denominator into 1 + 0 * inf = nan instead of 1, so cap the
    # exponent at a value whose exponential is still finite.
    with np.errstate(over="ignore"):
        growth = np.exp(np.minimum(exponent, 709.0))
        # b * growth may still overflow to inf for b > 1; K / inf == 0 is the
        # correct limit, so the overflow warning is suppressed here.
        return K / (1 + b * growth)
# Function to update the RCS parameters in the modified RCS metadata data
def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
    """Copy the K, b, a and x0 widget values into the modified RCS metadata.

    The values are read from the session-state keys
    ``<param>_updated_key_<metric>_<panel>_<channel>`` and replace the entry
    stored under ``[metric][panel][channel]`` of the modified metadata.

    Args:
        metrics_selected: Metric key of the entry to overwrite.
        panel_selected: Panel key of the entry to overwrite.
        channel_selected: Channel key of the entry to overwrite.
    """
    key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"

    # Read the four widget values first, in K, b, a, x0 order
    updated_values = {
        name: st.session_state[f"{name}_updated_key_{key_suffix}"]
        for name in ("K", "b", "a", "x0")
    }

    # The nested metadata dict is mutated in place, so the stored entry is
    # replaced by a dict holding exactly these four parameters.
    response_curves = st.session_state["project_dct"]["response_curves"]
    modified_metadata = response_curves["modified_metadata_file"]
    modified_metadata[metrics_selected][panel_selected][
        channel_selected
    ] = updated_values
# Function to reset the parameters to their default values
def reset_parameters(
    metrics_selected, panel_selected, channel_selected, original_channel_data
):
    """Restore the original K, b, a and x0 values for one selection.

    Deletes the four widget keys for the selection from session state and
    overwrites the ``[metric][panel][channel]`` entry of the modified RCS
    metadata with the values from ``original_channel_data``.

    Args:
        metrics_selected: Metric key of the entry to reset.
        panel_selected: Panel key of the entry to reset.
        channel_selected: Channel key of the entry to reset.
        original_channel_data: Mapping providing the original "K", "b", "a"
            and "x0" values.

    Raises:
        KeyError: If one of the widget keys is not present in session state.
    """
    param_names = ("K", "b", "a", "x0")
    key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"

    # Drop the widget values; the page seeds any missing key again from the
    # modified metadata before rendering the inputs.
    for name in param_names:
        del st.session_state[f"{name}_updated_key_{key_suffix}"]

    # Overwrite the stored entry with the original parameter values
    response_curves = st.session_state["project_dct"]["response_curves"]
    modified_metadata = response_curves["modified_metadata_file"]
    modified_metadata[metrics_selected][panel_selected][channel_selected] = {
        name: original_channel_data[name] for name in param_names
    }

    # Store the metadata object back under its session-state key
    response_curves["modified_metadata_file"] = modified_metadata
# Function to generate updated RCS parameter DataFrame
def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
    """Build a table of channels whose S-curve parameters were changed.

    Args:
        original_data: Nested mapping ``[metric][panel][channel]`` whose
            entries provide the original "K", "b", "a" and "x0" values.
        modified_data: Mapping of the same layout with the modified values.
        metrics_selected: Metric key to compare.
        panel_selected: Panel key to compare.

    Returns:
        A DataFrame with one row per channel where at least one of the four
        parameters differs, holding the formatted metric, panel and channel
        names plus the original and modified values. The frame is empty when
        nothing differs.
    """
    param_names = ("K", "b", "a", "x0")
    original_channels = original_data[metrics_selected][panel_selected]
    modified_channels = modified_data[metrics_selected][panel_selected]

    rows = []
    for channel, original_entry in original_channels.items():
        original_params = {name: original_entry[name] for name in param_names}
        modified_entry = modified_channels[channel]
        modified_params = {name: modified_entry[name] for name in param_names}

        # Channels whose four parameters all match are left out of the table
        changed = any(
            original_params[name] != modified_params[name] for name in param_names
        )
        if not changed:
            continue

        row = {
            "Metric": name_formating(metrics_selected),
            "Panel": name_formating(panel_selected),
            "Channel": name_formating(channel),
        }
        row.update(
            {f"{name} (Original)": value for name, value in original_params.items()}
        )
        row.update(
            {f"{name} (Modified)": value for name, value in modified_params.items()}
        )
        rows.append(row)

    return pd.DataFrame(rows)
# Function to create JSON file for RCS data
def create_json_file():
    """Serialize the modified RCS metadata from session state to JSON.

    Returns:
        The ``modified_metadata_file`` entry of the project's response-curves
        data as a JSON string indented with four spaces.
    """
    response_curves = st.session_state["project_dct"]["response_curves"]
    modified_metadata = response_curves["modified_metadata_file"]
    return json.dumps(modified_metadata, indent=4)
# Page body: selection widgets, save-progress button, S-curve plot, R-squared
# comparison, parameter editing, reset/download and the table of changes.
# Any unexpected error is logged with its traceback and shown as a warning.
try:
    # Page Title
    st.title("Response Curves")

    # Retrieve the list of all metric names
    metrics_list = get_metrics_names()

    # Without any metric there is nothing to plot: warn, log and end the run
    if not metrics_list:
        no_model_message = (
            "Please tune at least one model to generate response curves data."
        )
        st.warning(no_model_message, icon="⚠️")
        log_message("warning", no_model_message, "Response Curves")
        st.stop()

    # Widget columns
    metric_col, channel_col, panel_col, save_progress_col = st.columns(4)

    # Metrics Selection
    metrics_selected = metric_col.selectbox(
        "Response Metrics",
        sorted(metrics_list),
        format_func=name_formating,
        key="response_metrics_selectbox",
        index=0,
    )

    # Retrieve the list of all panel names for the selected metric
    panel_list = get_panels_names(metrics_selected)

    # Panel Selection
    panel_selected = panel_col.selectbox(
        "Panel",
        sorted(panel_list),
        format_func=name_formating,
        key="panel_selected_selectbox",
        index=0,
    )

    # Save Progress button slot
    with save_progress_col:
        st.write("####")  # Padding
        save_progress_placeholder = st.empty()

    # Container that hosts the spinner and the confirmation message
    message_spinner_placeholder = st.container()

    # Save page progress. The spinner wraps only the database write so that it
    # is not rendered on every rerun where the button was not pressed.
    if save_progress_placeholder.button("Save Progress", use_container_width=True):
        with message_spinner_placeholder, st.spinner("Saving Progress ..."):
            update_db(
                prj_id=st.session_state["project_number"],
                page_nam="Response Curves",
                file_nam="project_dct",
                pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                schema=schema,
            )

        message_spinner_placeholder.success(
            "Progress saved successfully!", icon="💾"
        )
        st.toast("Progress saved successfully!", icon="💾")
        log_message("info", "Progress saved successfully!", "Response Curves")

    # Generate the RCS metadata if either the original or the modified copy
    # is missing from the project dictionary
    response_curves_state = st.session_state["project_dct"]["response_curves"]
    if (
        response_curves_state["original_metadata_file"] is None
        or response_curves_state["modified_metadata_file"] is None
    ):
        generate_rcs_data()
        log_message(
            "info",
            "RCS metadata file does not exist. Generating new RCS data.",
            "Response Curves",
        )

    # Load metadata files
    original_data, modified_data = load_rcs_metadata_files()

    # Channel names available for the selected metric and panel
    chanel_list_final = list(original_data[metrics_selected][panel_selected].keys())

    # Channel Selection
    channel_selected = channel_col.selectbox(
        "Channel",
        sorted(chanel_list_final),
        format_func=name_formating,
        key="selected_channel_name_selectbox",
    )

    # Original and modified entries for the selected metric, panel and channel
    original_channel_data = original_data[metrics_selected][panel_selected][
        channel_selected
    ]
    modified_channel_data = modified_data[metrics_selected][panel_selected][
        channel_selected
    ]

    # Observed points, plus the x range used to draw the curves
    x = original_channel_data["x"]
    y = original_channel_data["y"]
    x_plot = original_channel_data["x_plot"]

    # x values are divided by 10**power before being fed to the S-curve
    power = original_channel_data["power"]
    x_scaled = np.array(x) / 10**power
    x_plot_scaled = np.array(x_plot) / 10**power

    # Original and modified (user-adjusted) S-curve parameters, in the
    # positional order expected by s_curve: K, b, a, x0
    param_names = ("K", "b", "a", "x0")
    original_params = [original_channel_data[name] for name in param_names]
    modified_params = [modified_channel_data[name] for name in param_names]

    # Scatter plot of the original data points
    fig = px.scatter(
        x=x,
        y=y,
        title="Original and Modified S-Curve Plot",
        labels={"x": "Spends", "y": name_formating(metrics_selected)},
    )

    # Modified S-curve trace
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(x_plot_scaled, *modified_params),
            line=dict(color="red"),
            name="Modified",
        ),
    )

    # Original S-curve trace
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(x_plot_scaled, *original_params),
            line=dict(color="rgba(0, 255, 0, 0.6)"),  # Semi-transparent green
            name="Original",
        ),
    )

    # Final titles (these override the ones passed to px.scatter above)
    fig.update_layout(
        title="Comparison of Original and Modified Response-Curves",
        xaxis_title="Input (Clicks, Impressions, etc..)",
        yaxis_title=name_formating(metrics_selected),
        legend_title="Curve Type",
    )

    # Display s-curve
    st.plotly_chart(fig, use_container_width=True)

    # R-squared of both curves against the observed points, and their difference
    r2_orig = r2_score(y, s_curve(x_scaled, *original_params))
    r2_mod = r2_score(y, s_curve(x_scaled, *modified_params))
    r2_diff = r2_mod - r2_orig

    # Display R-squared metrics
    st.write("## R-squared Comparison")
    r2_col = st.columns(3)
    r2_col[0].metric("R-squared (Original)", f"{r2_orig:.2f}")
    r2_col[1].metric("R-squared (Modified)", f"{r2_mod:.2f}")
    r2_col[2].metric("Difference in R-squared", f"{r2_diff:.2f}")

    # Session-state key of each parameter widget for the current selection
    selection_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
    param_keys = {
        name: f"{name}_updated_key_{selection_suffix}" for name in param_names
    }

    # Seed missing widget values from the modified metadata
    for name in param_names:
        if param_keys[name] not in st.session_state:
            st.session_state[param_keys[name]] = modified_channel_data[name]

    # RCS parameters input: one number input per parameter. Every change runs
    # modify_rcs_parameters, which copies all four values into the metadata.
    for column, name in zip(st.columns(4), param_names):
        with column:
            st.number_input(
                name,
                step=0.001,
                min_value=0.000000,
                format="%.6f",
                on_change=modify_rcs_parameters,
                args=(metrics_selected, panel_selected, channel_selected),
                key=param_keys[name],
            )

    # Columns for the Reset and Download buttons
    reset_download_col = st.columns(2)

    with reset_download_col[0]:
        if st.button(
            "Reset",
            use_container_width=True,
        ):
            reset_parameters(
                metrics_selected,
                panel_selected,
                channel_selected,
                original_channel_data,
            )
            log_message(
                "info",
                f"METRIC: {name_formating(metrics_selected)} ; PANEL: {name_formating(panel_selected)}, CHANNEL: {name_formating(channel_selected)} has been reset to its original value.",
                "Response Curves",
            )
            # Rerun so the widgets and the plot pick up the restored values
            st.rerun()

    with reset_download_col[1]:
        # Offering the download is best effort: a failure must not take the
        # page down, but it is logged instead of being silently discarded.
        try:
            json_data = create_json_file()
            st.download_button(
                label="Download",
                data=json_data,
                file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
                mime="application/json",
                use_container_width=True,
            )
        except Exception:
            log_message(
                "warning",
                f"Download of RCS data is unavailable: {traceback.format_exc()}",
                "Response Curves",
            )

    # DataFrame listing only the channels whose parameters were changed
    updated_parm_df = updated_parm_gen(
        original_data, modified_data, metrics_selected, panel_selected
    )

    # Display the DataFrame or an informational message if nothing changed
    if not updated_parm_df.empty:
        st.write("## Parameter Comparison for Selected Metric and Panel")
        st.dataframe(updated_parm_df, hide_index=True)
    else:
        st.info("No parameters are updated for the selected Metric and Panel")

except Exception:
    # Capture the full traceback of the active exception
    error_message = traceback.format_exc()

    # Log message
    log_message("error", f"An error occurred: {error_message}.", "Response Curves")

    # Display a warning message
    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )