|
import pickle |
|
|
|
import pandas as pd |
|
|
|
import shap |
|
|
|
import gradio as gr |
|
|
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
# Load the trained classifier from disk. The context manager closes the file
# handle as soon as unpickling finishes instead of leaving it to the GC.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# ship a model artifact that comes from a trusted source.
with open("salar_xgb_team.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Built once at import time and reused by every prediction request.
explainer = shap.Explainer(loaded_model)
|
|
|
|
|
|
|
# Education labels in ascending order. A label's 1-based position in this
# tuple is the numeric code it maps to.
_EDUCATION_LEVELS = (
    "Preschool",
    "1st-4th",
    "5th-6th",
    "7th-8th",
    "9th",
    "10th",
    "11th",
    "12th",
    "HS-grad",
    "Some-college",
    "Assoc-voc",
    "Assoc-acdm",
    "Bachelors",
    "Masters",
    "Prof-school",
    "Doctorate",
)

# Label -> numeric code (1..16), in the same order as the tuple above.
education_map = {
    level: code for code, level in enumerate(_EDUCATION_LEVELS, start=1)
}
|
|
|
|
|
|
|
# Class labels shown in the probability widget. They are spelled with escapes
# so the glyphs survive a non-UTF-8 round trip of this file.
_LABEL_LOW = "\N{LESS-THAN OR EQUAL TO}50K"
_LABEL_HIGH = ">50K"
_BRIEFCASE = "\N{BRIEFCASE}"
_KNOWN_SEXES = ("Male", "Female")


def _invalid_result():
    """Return the placeholder outputs shown when the inputs are rejected."""
    # NOTE(review): the leading glyph of this message was mis-encoded beyond
    # recovery; a cross mark is assumed -- confirm against the intended text.
    return (
        {_LABEL_LOW: 0.0, _LABEL_HIGH: 0.0},
        None,
        "\N{CROSS MARK} Invalid inputs. Please check your entries.",
    )


def _inputs_are_valid(age, sex, capital_gain, capital_loss, hours_per_week):
    """Return True when every form value is present and within range."""
    # A cleared numeric field arrives as None; comparing None with an int
    # would raise TypeError, so reject it up front.
    if any(value is None for value in (age, capital_gain, capital_loss, hours_per_week)):
        return False
    # An unselected radio (None) or an unknown label must not be silently
    # encoded as one of the two sexes.
    if sex not in _KNOWN_SEXES:
        return False
    # Chained comparisons are False for NaN, so NaN is rejected as well.
    if not (18 <= age <= 100 and 1 <= hours_per_week <= 100):
        return False
    return capital_gain >= 0 and capital_loss >= 0


def _render_shap_bar(explanation):
    """Draw the SHAP bar chart for one explained row and return its figure."""
    plt.figure(figsize=(8, 4))
    try:
        shap.plots.bar(explanation, max_display=6, show=False)
        plt.tight_layout()
        return plt.gcf()
    finally:
        # Close on every path (including a plotting error) so pyplot does not
        # accumulate open figures; the Figure object itself is still returned.
        plt.close()


def main_func(age, education_label, sex, capital_gain, capital_loss, hours_per_week):
    """Predict the income class for one person and explain the prediction.

    Args:
        age: Age in years; accepted range is 18-100.
        education_label: A key of ``education_map``; unknown labels fall back
            to code 9.
        sex: ``"Male"`` or ``"Female"``.
        capital_gain: Non-negative capital gain.
        capital_loss: Non-negative capital loss.
        hours_per_week: Weekly working hours; accepted range is 1-100.

    Returns:
        A 3-tuple ``(probabilities, figure, summary)``: a dict mapping the two
        class labels to probabilities rounded to two decimals, the SHAP bar
        chart figure, and a short Markdown-style summary string. For rejected
        inputs the probabilities are both 0.0, the figure is ``None`` and the
        summary is an error message.
    """
    if not _inputs_are_valid(age, sex, capital_gain, capital_loss, hours_per_week):
        return _invalid_result()

    education_num = education_map.get(education_label, 9)
    sex_binary = 0 if sex == "Male" else 1

    # One-row frame; column names presumably match the ones the model was
    # trained on -- verify against the training pipeline.
    new_row = pd.DataFrame({
        'age': [age],
        'education-num': [education_num],
        'sex': [sex_binary],
        'capital-gain': [capital_gain],
        'capital-loss': [capital_loss],
        'hours-per-week': [hours_per_week],
    })

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)
    local_plot = _render_shap_bar(shap_values[0])

    # Convert to built-in floats before rounding so the reported values are
    # plain Python numbers (predict_proba presumably yields NumPy scalars).
    prob_low = float(prob[0][0])
    prob_high = float(prob[0][1])

    if prob_high > 0.5:
        pred_class, confidence = _LABEL_HIGH, prob_high
    else:
        pred_class, confidence = _LABEL_LOW, prob_low

    # The percentage is formatted from the unrounded probability; rounding to
    # two decimals first would make it always read "xx.00%".
    interpretation = (
        f"{_BRIEFCASE} Prediction: **{pred_class}**\n"
        f"Confidence: {confidence * 100:.2f}%"
    )

    probabilities = {
        _LABEL_LOW: round(prob_low, 2),
        _LABEL_HIGH: round(prob_high, 2),
    }
    return probabilities, local_plot, interpretation
|
|
|
|
|
|
|
# Page copy. ``title`` carries Markdown emphasis markers; the trailing emoji is
# written as an escape so it survives a non-UTF-8 round trip of this file.
title = "**Salary Predictor & SHAP Explainer** \N{MONEY BAG}"

description1 = (
    "This app uses demographic and financial info to predict whether someone "
    "earns more than $50K annually."
)

description2 = (
    "Adjust the inputs and click **Analyze** to see prediction and SHAP "
    "feature contributions."
)
|
|
|
# Assemble the page: input widgets in the left column, model outputs in the
# right column, then example rows and a footer.
# The browser-tab title is plain text, so the Markdown emphasis markers that
# ``title`` carries for the on-page heading are stripped for it.
with gr.Blocks(title=title.replace("**", "")) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    gr.Markdown(description2)
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            age = gr.Number(label="Age", value=35, precision=0)
            education_label = gr.Dropdown(
                choices=list(education_map.keys()),
                label="Education Level",
                value="HS-grad",
            )
            # Pre-select a choice so the callback never receives None for an
            # untouched radio group.
            sex = gr.Radio(["Male", "Female"], label="Sex", value="Male")
            capital_gain = gr.Number(label="Capital Gain", value=0)
            capital_loss = gr.Number(label="Capital Loss", value=0)
            hours_per_week = gr.Slider(
                label="Hours Worked per Week",
                minimum=1,
                maximum=100,
                value=40,
            )
            # NOTE(review): the button's emoji was mis-encoded beyond
            # recovery; a magnifying glass is assumed -- confirm.
            submit_btn = gr.Button("\N{LEFT-POINTING MAGNIFYING GLASS} Analyze")

        with gr.Column(scale=1):
            label = gr.Label(label="Predicted Probabilities")
            local_plot = gr.Plot(label="SHAP Feature Importance")
            result_text = gr.Textbox(label="Prediction Summary", lines=2)

    # Inputs are passed positionally, in main_func's parameter order; the
    # three outputs match the 3-tuple it returns.
    submit_btn.click(
        main_func,
        [age, education_label, sex, capital_gain, capital_loss, hours_per_week],
        [label, local_plot, result_text],
        api_name="Salary_Predictor",
    )

    gr.Markdown("### Try one of the following examples:")
    gr.Examples(
        examples=[
            [28, "Some-college", "Male", 0, 0, 45],
            [52, "Masters", "Female", 7688, 0, 60],
            [35, "HS-grad", "Male", 0, 1902, 40],
        ],
        inputs=[age, education_label, sex, capital_gain, capital_loss, hours_per_week],
        outputs=[label, local_plot, result_text],
        fn=main_func,
        cache_examples=True,
    )

    gr.Markdown("---")
    gr.Markdown("Built with love by Team 3 for the 2025 AI Applications Project!")

    gr.Markdown("---")
    # NOTE(review): the footer emoji was mis-encoded beyond recovery; a party
    # popper is assumed -- confirm.
    gr.Markdown("\N{PARTY POPPER} Thanks for using the Salary Predictor!")
    gr.Image(
        value="https://media.giphy.com/media/l0MYt5jPR6QX5pnqM/giphy.gif",
        label="",
        show_label=False,
        show_download_button=False,
        height=200,
    )

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|