import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt
from read_predictions_from_db import PredictionDBRead
from read_daily_metrics_from_db import MetricsDBRead
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import logging
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD,
                    PERFORMANCE_THRESHOLD,
                    CLASSIFIER_THRESHOLD)

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)


def filter_prediction_data(data: pd.DataFrame):
    try:
        logging.info("Entering filter_prediction_data()")
        if data is None:
            raise Exception("Input Prediction Data frame is None")
        filtered_prediction_data = data.loc[
            (data['y_true_proba'] == 1) &
            (~data['used_for_training'].astype("str").str.contains("_train")) &
            (~data['used_for_training'].astype("str").str.contains("_excluded")) &
            (~data['used_for_training'].astype("str").str.contains("_validation"))
        ].copy()
        logging.info("Exiting filter_prediction_data()")
        return filtered_prediction_data
    except Exception as e:
        logging.critical(f"Error in filter_prediction_data(): {e}")
        return None
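
# Minimal usage sketch (assumption: column names and tag values are inferred from the
# filters above and are illustrative only, not real data). Rows whose `used_for_training`
# tag contains _train/_excluded/_validation are dropped, and only rows with
# y_true_proba == 1 are kept.
#
#     _demo = pd.DataFrame({
#         'y_true_proba': [1, 1, 0],
#         'used_for_training': ['2024-01-01_train', 'scoring_only', 'scoring_only'],
#         'y_true': ['WORLD', 'NATION', 'SCIENCE'],
#         'y_pred': ['WORLD', 'NATION', 'NATION'],
#     })
#     filter_prediction_data(_demo)   # keeps only the second row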


def get_adjusted_predictions(df):
    try:
        logging.info("Entering get_adjusted_predictions()")
        if df is None:
            raise Exception('Input Filtered Prediction Data Frame is None')
        df = df.copy()
        df.reset_index(drop=True, inplace=True)
        # Predictions below the adjustment threshold are reassigned to 'NATION'
        df.loc[df['y_pred_proba'] < CLASSIFIER_ADJUSTMENT_THRESHOLD, 'y_pred'] = 'NATION'
        # df.loc[(df['text'].str.contains('Pakistan')) & (df['y_pred'] == 'NATION'), 'y_pred'] = 'WORLD'
        # df.loc[(df['text'].str.contains('Zodiac Sign', case=False)) | (df['text'].str.contains('Horoscope', case=False)), 'y_pred'] = 'SCIENCE'
        logging.info("Exiting get_adjusted_predictions()")
        return df
    except Exception as e:
        logging.critical(f"Error in get_adjusted_predictions(): {e}")
        return None


def display_kpis(data: pd.DataFrame, adj_data: pd.DataFrame):
    try:
        logging.info("Entering display_kpis()")
        if data is None:
            raise Exception("Input Prediction Data frame is None")
        if adj_data is None:
            raise Exception('Input Adjusted Data frame is None')
        n_samples = len(data)
        balanced_accuracy = np.round(balanced_accuracy_score(data['y_true'], data['y_pred']), 4)
        accuracy = np.round(accuracy_score(data['y_true'], data['y_pred']), 4)
        adj_balanced_accuracy = np.round(balanced_accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        adj_accuracy = np.round(accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        st.write('''<style>
                 [data-testid="column"] {
                     width: calc(33.3333% - 1rem) !important;
                     flex: 1 1 calc(33.3333% - 1rem) !important;
                     min-width: calc(33% - 1rem) !important;
                 }
                 </style>''',
                 unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            st.metric(label="Balanced Accuracy", value=balanced_accuracy)
        with col2:
            st.metric(label="Adj Balanced Accuracy", value=adj_balanced_accuracy)
        col3, col4 = st.columns(2)
        with col3:
            st.metric(label="Accuracy", value=accuracy)
        with col4:
            st.metric(label="Adj Accuracy", value=adj_accuracy)
        col5, col6 = st.columns(2)
        with col5:
            st.metric(label="Bal Accuracy Threshold", value=PERFORMANCE_THRESHOLD)
        with col6:
            st.metric(label="N Samples", value=n_samples)
        logging.info("Exiting display_kpis()")
    except Exception as e:
        logging.critical(f'Error in display_kpis(): {e}')
        st.error("Couldn't display KPIs")


def plot_daily_metrics(metrics_df: pd.DataFrame):
    try:
        logging.info("Entering plot_daily_metrics()")
        st.write(" ")
        if metrics_df is None:
            raise Exception('Input Metrics Data Frame is None')
        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])
        metrics_df['mean_score_minus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] - metrics_df['std_balanced_accuracy_score'], 4)
        metrics_df['mean_score_plus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] + metrics_df['std_balanced_accuracy_score'], 4)
        hover_data = {'mean_balanced_accuracy_score': True,
                      'std_balanced_accuracy_score': False,
                      'mean_score_minus_std': True,
                      'mean_score_plus_std': True,
                      'evaluation_window_days': True,
                      'n_splits': True,
                      'sample_start_date': True,
                      'sample_end_date': True,
                      'sample_size_of_each_split': True}
        hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
                        'mean_score_minus_std': "Mean Score - Stdev",
                        'mean_score_plus_std': "Mean Score + Stdev",
                        'evaluation_window_days': "Observation Window (Days)",
                        'sample_start_date': "Observation Window Start Date",
                        'sample_end_date': "Observation Window End Date",
                        'n_splits': "N Splits For Evaluation",
                        'sample_size_of_each_split': "Sample Size of Each Split"}
        fig = px.line(data_frame=metrics_df, x='evaluation_date',
                      y='mean_balanced_accuracy_score',
                      error_y='std_balanced_accuracy_score',
                      title="Daily Balanced Accuracy",
                      color_discrete_sequence=['black'],
                      hover_data=hover_data, labels=hover_labels, markers=True)
        fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
                      annotation_text="<b>THRESHOLD</b>",
                      annotation_position="left top")
        fig.update_layout(dragmode='pan')
        fig.update_layout(margin=dict(l=0, r=0, t=110, b=10))
        st.plotly_chart(fig, use_container_width=True)
        logging.info("Exiting plot_daily_metrics()")
    except Exception as e:
        logging.critical(f'Error in plot_daily_metrics(): {e}')
        st.error("Couldn't Plot Daily Model Metrics")
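
# Assumed columns of the daily-metrics table, inferred from the hover fields above
# (the authoritative schema presumably lives behind read_daily_metrics_from_db.MetricsDBRead):
# evaluation_date, mean_balanced_accuracy_score, std_balanced_accuracy_score,
# evaluation_window_days, n_splits, sample_start_date, sample_end_date,
# sample_size_of_each_split.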


def get_misclassified_classes(data):
    try:
        logging.info("Entering get_misclassified_classes()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        data = data.copy()
        data['match'] = (data['y_true'] == data['y_pred']).astype('int')
        y_pred_counts = data['y_pred'].value_counts()
        misclassified_examples = data.loc[data['match'] == 0, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
        misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'], ascending=[True, False], inplace=True)
        misclassifications = data.loc[data['match'] == 0, 'y_pred'].value_counts()
        missing_classes = [i for i in y_pred_counts.index if i not in misclassifications.index]
        for i in missing_classes:
            misclassifications[i] = 0
        misclassifications = misclassifications[y_pred_counts.index].copy()
        misclassifications /= y_pred_counts
        misclassifications.sort_values(ascending=False, inplace=True)
        logging.info("Exiting get_misclassified_classes()")
        return np.round(misclassifications, 2), misclassified_examples
    except Exception as e:
        logging.critical(f'Error in get_misclassified_classes(): {e}')
        return None, None
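
# The series returned above holds, per predicted class, the fraction of its predictions
# that were wrong: count(y_pred == c and y_true != c) / count(y_pred == c).
# For example (illustrative numbers only): if 'SPORTS' was predicted 40 times and 4 of
# those rows carry a different y_true, its entry is 4 / 40 = 0.10.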


def display_misclassified_examples(misclassified_classes, misclassified_examples):
    try:
        logging.info("Entering display_misclassified_examples()")
        st.write(" ")
        if misclassified_classes is None:
            raise Exception('Misclassified Classes Distribution Data Frame is None')
        if misclassified_examples is None:
            raise Exception('Misclassified Examples Data Frame is None')
        fig, ax = plt.subplots(figsize=(10, 4.5))
        misclassified_classes.plot(kind='bar', ax=ax, color='black', title="Misclassification percentage")
        plt.yticks([])
        plt.xlabel("")
        ax.bar_label(ax.containers[0])
        st.pyplot(fig)
        st.markdown("<b>Misclassified examples</b>", unsafe_allow_html=True)
        st.dataframe(misclassified_examples, hide_index=True)
        st.markdown(
            """
            <style>
            [data-testid="stElementToolbar"] {
                display: none;
            }
            </style>
            """,
            unsafe_allow_html=True
        )
        logging.info("Exiting display_misclassified_examples()")
    except Exception as e:
        logging.critical(f'Error in display_misclassified_examples(): {e}')
        st.error("Couldn't display Misclassification Data")


def classification_model_monitor():
    try:
        # st.write('<h4>Classification Model Monitor<span style="color: red;"> (out of service)</span></h4>', unsafe_allow_html=True)
        st.write('<h4>Classification Model Monitor</h4>', unsafe_allow_html=True)
        prediction_db = PredictionDBRead()
        metrics_db = MetricsDBRead()
        # Read Prediction Data From DB
        prediction_data = prediction_db.read_predictions_from_db()
        # Filter Prediction Data
        filtered_prediction_data = filter_prediction_data(prediction_data)
        # Get Adjusted Prediction Data
        adjusted_filtered_prediction_data = get_adjusted_predictions(filtered_prediction_data)
        # Display KPIs
        display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data)
        # Read Daily Metrics From DB
        metrics_df = metrics_db.read_metrics_from_db()
        # Display daily Metrics Line Plot
        plot_daily_metrics(metrics_df)
        # Get misclassified class distribution and misclassified examples from Prediction Data
        misclassified_classes, misclassified_examples = get_misclassified_classes(filtered_prediction_data)
        # Display Misclassification Data
        display_misclassified_examples(misclassified_classes, misclassified_examples)
        st.markdown(
            """<style>
            [data-testid="stMetricValue"] {
                font-size: 25px;
            }
            </style>
            """, unsafe_allow_html=True
        )
    except Exception as e:
        logging.critical(f"Error in classification_model_monitor(): {e}")
        st.error("Unexpected Error. Couldn't display Classification Model Monitor")
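

# Hypothetical entry point, assuming this module is also run directly as a standalone
# Streamlit page (e.g. `streamlit run <this file>`); in a multi-page app the caller
# would import and invoke classification_model_monitor() instead.
if __name__ == "__main__":
    classification_model_monitor()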