"""Streamlit front end for the customer churn predictor.

Loads several pre-trained classifiers, lets the user pick (and tweak) a
customer from ``churn.csv``, shows the per-model and averaged churn
probability, and asks an LLM (via Groq's OpenAI-compatible API) for a short
natural-language explanation.

NOTE(review): Streamlit re-executes this script on every widget interaction,
so the models, the CSV and the LLM call are reloaded / re-issued each time.
Consider ``st.cache_resource`` / ``st.cache_data`` if the installed Streamlit
version provides them.
"""
import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
from openai import OpenAI

import utils as ut

# Prefer the environment variable; otherwise fall back to Streamlit secrets.
if "GROQ_API_KEY" in os.environ:
    api_key = os.environ.get("GROQ_API_KEY")
else:
    api_key = st.secrets["GROQ_API_KEY"]

# Groq exposes an OpenAI-compatible endpoint, so the OpenAI client is reused.
client = OpenAI(
    base_url="https://api.groq.com/openai/v1",
    api_key=api_key,
)

# Chat model used to generate the prediction explanation.
LLM_MODEL = "llama-3.2-3b-preview"

# Countries offered in the location selector (also the display order).
LOCATIONS = ["France", "Spain", "Germany"]


def load_model(file_name):
    """Load and return the pickled object stored in *file_name*.

    NOTE: ``pickle.load`` can execute arbitrary code. Only load model files
    that ship with this app, never user-supplied files.
    """
    with open(file_name, 'rb') as file:
        return pickle.load(file)


xgb_model = load_model('xgb_model.pkl')
naive_bayes_model = load_model('nb_model.pkl')
random_forest_model = load_model('rf_model.pkl')
decision_tree_model = load_model('dt_model.pkl')
knn_model = load_model('knn_model.pkl')


def prepare_input_data(credit_score, location, gender, age, tenure, balance,
                       num_products, has_credit_card, is_active_member,
                       estimated_salary):
    """Build the single-row feature frame passed to the models.

    ``location`` and ``gender`` are one-hot encoded into the ``Geography_*``
    and ``Gender_*`` columns. The two checkbox flags are stored as 0/1
    integers rather than bools, presumably matching the numeric encoding the
    models were trained on (confirm against the training pipeline).

    Key order is preserved deliberately: estimators fitted on a DataFrame may
    require the same column order at prediction time.

    Returns:
        ``(input_df, input_dict)``: the one-row DataFrame and the dict it was
        built from.
    """
    input_dict = {
        'CreditScore': credit_score,
        'Age': age,
        'Tenure': tenure,
        'Balance': balance,
        'NumOfProducts': num_products,
        'HasCrCard': int(has_credit_card),
        'IsActiveMember': int(is_active_member),
        'EstimatedSalary': estimated_salary,
        'Geography_France': 1 if location == 'France' else 0,
        'Geography_Germany': 1 if location == 'Germany' else 0,
        'Geography_Spain': 1 if location == 'Spain' else 0,
        'Gender_Male': 1 if gender == 'Male' else 0,
        'Gender_Female': 1 if gender == 'Female' else 0,
    }
    input_df = pd.DataFrame([input_dict])
    return input_df, input_dict


def make_prediction(input_df, input_dict):
    """Score *input_df* with every model, render the results, return the mean.

    Each model contributes ``predict_proba(...)[0, 1]``, i.e. the probability
    of class 1 for the single input row (presumably "churned"). A gauge of the
    averaged probability and a per-model comparison chart are drawn side by
    side.

    Args:
        input_df: one-row feature frame from ``prepare_input_data``.
        input_dict: unused; kept so existing call sites keep working.

    Returns:
        The unweighted mean of the model probabilities (a value in 0..1).
    """
    probabilities = {
        'XGBoost': xgb_model.predict_proba(input_df)[0, 1],
        'Naive Bayes': naive_bayes_model.predict_proba(input_df)[0, 1],
        'Random Forest': random_forest_model.predict_proba(input_df)[0, 1],
        'Decision Tree': decision_tree_model.predict_proba(input_df)[0, 1],
        'K-Nearest Neighbors': knn_model.predict_proba(input_df)[0, 1],
    }
    avg_probability = np.mean(list(probabilities.values()))

    col1, col2 = st.columns(2)
    with col1:
        fig = ut.create_guage_chart(avg_probability)
        st.plotly_chart(fig, use_container_width=True)
        # avg_probability is a 0..1 fraction. The ``%`` format type scales it
        # by 100 before appending the percent sign.
        st.write(f"The customer has a {avg_probability:.2%} probability of churning.")
    with col2:
        fig = ut.create_model_probability_chart(probabilities)
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Model Probabilities")
    for model, prob in probabilities.items():
        st.markdown(f"{model}: {prob:.2f}")
    st.markdown(f"### Average Probability: {avg_probability:.2f}")

    return avg_probability


def explain_prediction(probability, input_dict, surname, data=None):
    """Ask the LLM for a short, plain-language explanation of a prediction.

    Args:
        probability: churn probability as a 0..1 fraction.
        input_dict: the customer's feature dict (embedded in the prompt).
        surname: customer surname, used to address the customer in the prompt.
        data: customer DataFrame with an ``Exited`` column, used for the
            churned / non-churned summary statistics. Defaults to the
            module-level ``df`` for backward compatibility.

    Returns:
        The text content of the model's first completion choice.
    """
    stats_df = df if data is None else data
    # to_string() renders every column without mutating global pandas
    # display options.
    churned_stats = stats_df[stats_df['Exited'] == 1].describe().to_string()
    retained_stats = stats_df[stats_df['Exited'] == 0].describe().to_string()

    prompt = f"""You are an expert data scientist at a bank, where you specialize in interpreting and explaining predictions of machine learning models.

A customer with the name {surname} has been assessed as having a {round(probability * 100, 1)}% likelihood of churning based on their profile and engagement.

Here is the customer's information:
{input_dict}

Here are the machine learning model's most influential features affecting churn:

Feature | Importance:
-------------------------------
NumOfProducts | 0.323888
IsActiveMember | 0.164146
Age | 0.109550
Geography_Germany | 0.091373
Balance | 0.052786
Geography_France | 0.046463
Gender_Female | 0.045283
Geography_Spain | 0.036855
CreditScore | 0.035005
EstimatedSalary | 0.032655
HasCrCard | 0.031940
Tenure | 0.030054
Gender_Male | 0.000000

Here are the summary statistics for churned customers:
{churned_stats}

Here are the summary statistics for non-churned customers:
{retained_stats}

Based on the customer’s probability of churning:
- If the probability is above 40%, generate a brief 3-sentence explanation outlining why the customer is at risk of churning.
- If the probability is 40% or below, generate a 3-sentence explanation of why the customer may not be at risk of churning.

The output should only be the explanation itself, based on the customer's information, the summary statistics of churned and non-churned customers, and the most influential features, without mentioning probability, model, or feature names. No extra text or summaries are needed.
"""

    raw_response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.5,
    )
    return raw_response.choices[0].message.content


st.title("Customer Churn Predictor")

df = pd.read_csv('churn.csv')

# "<CustomerId> - <Surname>" labels. Zipping the two columns avoids the
# per-row Series construction that DataFrame.iterrows performs.
customers = [
    f"{customer_id} - {surname}"
    for customer_id, surname in zip(df["CustomerId"], df["Surname"])
]

selected_customer_option = st.selectbox("Select a customer", customers)

if selected_customer_option:
    # partition() splits only on the first " - ", so a surname that itself
    # contains " - " is kept whole.
    selected_customer_id, _, selected_customer_surname = (
        selected_customer_option.partition(' - ')
    )
    selected_customer = df.loc[df["CustomerId"] == int(selected_customer_id)].iloc[0]

    col1, col2 = st.columns(2)

    with col1:
        credit_score = st.number_input(
            "Credit Score",
            min_value=300,
            max_value=850,
            value=int(selected_customer["CreditScore"]),
        )
        location = st.selectbox(
            "Location",
            LOCATIONS,
            index=LOCATIONS.index(selected_customer["Geography"]),
        )
        gender = st.radio(
            "Gender",
            ["Male", "Female"],
            index=0 if selected_customer["Gender"] == "Male" else 1,
        )
        age = st.number_input(
            "Age",
            min_value=18,
            max_value=100,
            value=int(selected_customer["Age"]),
        )
        tenure = st.number_input(
            "Tenure (years)",
            min_value=0,
            max_value=50,
            value=int(selected_customer["Tenure"]),
        )

    with col2:
        balance = st.number_input(
            "Balance",
            min_value=0.0,
            value=float(selected_customer["Balance"]),
        )
        num_products = st.number_input(
            "Number of Products",
            min_value=1,
            max_value=10,
            value=int(selected_customer["NumOfProducts"]),
        )
        has_credit_card = st.checkbox(
            "Has Credit Card",
            value=bool(selected_customer["HasCrCard"]),
        )
        is_active_member = st.checkbox(
            "Active Member",
            value=bool(selected_customer["IsActiveMember"]),
        )
        estimated_salary = st.number_input(
            "Estimated Salary",
            min_value=0.0,
            value=float(selected_customer["EstimatedSalary"]),
        )

    input_df, input_dict = prepare_input_data(
        credit_score, location, gender, age, tenure, balance, num_products,
        has_credit_card, is_active_member, estimated_salary,
    )

    avg_probability = make_prediction(input_df, input_dict)

    explanation = explain_prediction(
        avg_probability, input_dict, selected_customer_surname, data=df
    )

    st.markdown("---")
    st.subheader("Explanation of the Prediction")
    st.markdown(explanation)