import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt
from read_predictions_from_db import PredictionDBRead
from read_daily_metrics_from_db import MetricsDBRead
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import logging
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD, PERFORMANCE_THRESHOLD,
                    CLASSIFIER_THRESHOLD)

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)


def filter_prediction_data(data: pd.DataFrame):
    """Keep only verified rows that were never used for training, validation, or exclusion."""
    try:
        logging.info("Entering filter_prediction_data()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        used_for_training = data['used_for_training'].astype("str")
        filtered_prediction_data = data.loc[
            (data['y_true_proba'] == 1)
            & ~used_for_training.str.contains("_train")
            & ~used_for_training.str.contains("_excluded")
            & ~used_for_training.str.contains("_validation")
        ].copy()
        logging.info("Exiting filter_prediction_data()")
        return filtered_prediction_data
    except Exception as e:
        logging.critical(f"Error in filter_prediction_data(): {e}")
        return None


def get_adjusted_predictions(df):
    """Return a copy of the predictions with low-confidence rows adjusted."""
    try:
        logging.info("Entering get_adjusted_predictions()")
        if df is None:
            raise Exception('Input Filtered Prediction Data Frame is None')
        df = df.copy()
        df.reset_index(drop=True, inplace=True)
        # NOTE: the tail of this function was lost; the rule below is a hedged
        # reconstruction. Assumption: predictions whose confidence falls below
        # CLASSIFIER_ADJUSTMENT_THRESHOLD would not have been acted on, so they are
        # re-labelled with the verified class before the "adjusted" KPIs are computed.
        low_confidence = df['y_pred_proba'] < CLASSIFIER_ADJUSTMENT_THRESHOLD
        df.loc[low_confidence, 'y_pred'] = df.loc[low_confidence, 'y_true']
        logging.info("Exiting get_adjusted_predictions()")
        return df
    except Exception as e:
        logging.critical(f"Error in get_adjusted_predictions(): {e}")
        return None
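
# The two helpers above assume a specific prediction-table schema. The snippet below
# is a hypothetical smoke test (not part of the monitor) that documents the columns
# the pipeline relies on; the toy frame, its values, and the _demo_ name are
# illustrative only.
def _demo_filter_and_adjust():
    toy = pd.DataFrame({
        'text': ['a', 'b', 'c'],
        'url': ['u1', 'u2', 'u3'],
        'y_true': ['news', 'blog', 'news'],
        'y_pred': ['news', 'news', 'blog'],
        'y_pred_proba': [0.95, 0.40, 0.90],
        'y_true_proba': [1, 1, 1],             # 1 => the true label has been verified
        'used_for_training': ['inference'] * 3,  # no _train/_excluded/_validation flag
    })
    filtered = filter_prediction_data(toy)       # keeps all three toy rows
    return filtered, get_adjusted_predictions(filtered)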
def display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data):
    try:
        logging.info("Entering display_kpis()")
        if filtered_prediction_data is None:
            raise Exception('Input Filtered Prediction Data Frame is None')
        if adjusted_filtered_prediction_data is None:
            raise Exception('Input Adjusted Prediction Data Frame is None')
        # NOTE: the function signature, the metric computations, and the opening of the
        # st.markdown() CSS call were lost; they are reconstructed here from the KPI
        # labels rendered below and the sklearn metrics imported at the top.
        balanced_accuracy = np.round(balanced_accuracy_score(
            filtered_prediction_data['y_true'], filtered_prediction_data['y_pred']), 4)
        adj_balanced_accuracy = np.round(balanced_accuracy_score(
            adjusted_filtered_prediction_data['y_true'], adjusted_filtered_prediction_data['y_pred']), 4)
        accuracy = np.round(accuracy_score(
            filtered_prediction_data['y_true'], filtered_prediction_data['y_pred']), 4)
        adj_accuracy = np.round(accuracy_score(
            adjusted_filtered_prediction_data['y_true'], adjusted_filtered_prediction_data['y_pred']), 4)
        n_samples = len(filtered_prediction_data)
        st.markdown('''
            <style>
            [data-testid="column"] {
                width: calc(33.3333% - 1rem) !important;
                flex: 1 1 calc(33.3333% - 1rem) !important;
                min-width: calc(33% - 1rem) !important;
            }
            </style>
            ''', unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            st.metric(label="Balanced Accuracy", value=balanced_accuracy)
        with col2:
            st.metric(label="Adj Balanced Accuracy", value=adj_balanced_accuracy)
        col3, col4 = st.columns(2)
        with col3:
            st.metric(label="Accuracy", value=accuracy)
        with col4:
            st.metric(label="Adj Accuracy", value=adj_accuracy)
        col5, col6 = st.columns(2)
        with col5:
            st.metric(label="Bal Accuracy Threshold", value=PERFORMANCE_THRESHOLD)
        with col6:
            st.metric(label="N Samples", value=n_samples)
        logging.info("Exiting display_kpis()")
    except Exception as e:
        logging.critical(f'Error in display_kpis(): {e}')
        st.error("Couldn't display KPIs")


def plot_daily_metrics(metrics_df: pd.DataFrame):
    try:
        logging.info("Entering plot_daily_metrics()")
        st.write(" ")
        if metrics_df is None:
            raise Exception('Input Metrics Data Frame is None')
        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])
        metrics_df['mean_score_minus_std'] = np.round(
            metrics_df['mean_balanced_accuracy_score'] - metrics_df['std_balanced_accuracy_score'], 4)
        metrics_df['mean_score_plus_std'] = np.round(
            metrics_df['mean_balanced_accuracy_score'] + metrics_df['std_balanced_accuracy_score'], 4)
        hover_data = {'mean_balanced_accuracy_score': True,
                      'std_balanced_accuracy_score': False,
                      'mean_score_minus_std': True,
                      'mean_score_plus_std': True,
                      'evaluation_window_days': True,
                      'n_splits': True,
                      'sample_start_date': True,
                      'sample_end_date': True,
                      'sample_size_of_each_split': True}
        hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
                        'mean_score_minus_std': "Mean Score - Stdev",
                        'mean_score_plus_std': "Mean Score + Stdev",
                        'evaluation_window_days': "Observation Window (Days)",
                        'sample_start_date': "Observation Window Start Date",
                        'sample_end_date': "Observation Window End Date",
                        'n_splits': "N Splits For Evaluation",
                        'sample_size_of_each_split': "Sample Size of Each Split"}
        fig = px.line(data_frame=metrics_df, x='evaluation_date',
                      y='mean_balanced_accuracy_score',
                      error_y='std_balanced_accuracy_score',
                      title="Daily Balanced Accuracy",
                      color_discrete_sequence=['black'],
                      hover_data=hover_data, labels=hover_labels, markers=True)
        fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
                      annotation_text="THRESHOLD", annotation_position="top left")
        fig.update_layout(dragmode='pan', margin=dict(l=0, r=0, t=110, b=10))
        st.plotly_chart(fig, use_container_width=True)
        logging.info("Exiting plot_daily_metrics()")
    except Exception as e:
        logging.critical(f'Error in plot_daily_metrics(): {e}')
        st.error("Couldn't Plot Daily Model Metrics")


def get_misclassified_classes(data):
    try:
        logging.info("Entering get_misclassified_classes()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        data = data.copy()
        data['match'] = (data['y_true'] == data['y_pred']).astype('int')
        y_pred_counts = data['y_pred'].value_counts()
        misclassified_examples = data.loc[
            data['match'] == 0, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
        misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'],
                                           ascending=[True, False], inplace=True)
        misclassifications = data.loc[data['match'] == 0, 'y_pred'].value_counts()
        # Classes that were never misclassified still need an explicit zero entry
        missing_classes = [i for i in y_pred_counts.index if i not in misclassifications.index]
        for i in missing_classes:
            misclassifications[i] = 0
        # Normalise by how often each class was predicted -> per-class misclassification rate
        misclassifications = misclassifications[y_pred_counts.index].copy()
        misclassifications /= y_pred_counts
        misclassifications.sort_values(ascending=False, inplace=True)
        logging.info("Exiting get_misclassified_classes()")
        return np.round(misclassifications, 2), misclassified_examples
    except Exception as e:
        logging.critical(f'Error in get_misclassified_classes(): {e}')
        return None, None


def display_misclassified_examples(misclassified_classes, misclassified_examples):
    try:
        logging.info("Entering display_misclassified_examples()")
        st.write(" ")
        if misclassified_classes is None:
            raise Exception('Misclassified Classes Distribution Data Frame is None')
        if misclassified_examples is None:
            raise Exception('Misclassified Examples Data Frame is None')
        fig, ax = plt.subplots(figsize=(10, 4.5))
        misclassified_classes.plot(kind='bar', ax=ax, color='black',
                                   title="Misclassification percentage")
        plt.yticks([])
        plt.xlabel("")
        ax.bar_label(ax.containers[0])
        st.pyplot(fig)
        st.markdown("Misclassified examples", unsafe_allow_html=True)
        st.dataframe(misclassified_examples, hide_index=True)
        # NOTE: the payload of this st.markdown() call (likely styling) was not recovered
        st.markdown(""" """, unsafe_allow_html=True)
        logging.info("Exiting display_misclassified_examples()")
    except Exception as e:
        logging.critical(f'Error in display_misclassified_examples(): {e}')
        st.error("Couldn't display Misclassification Data")
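
# For running this page without the production database, hypothetical stand-ins such
# as the ones below can mimic the two reader interfaces consumed further down. Only
# the method names mirror the real imports; the _Fake* class names are illustrative,
# and the column lists simply document the schemas the functions above expect.
class _FakePredictionDBRead:
    """Stand-in for PredictionDBRead (local testing only)."""

    def read_predictions_from_db(self) -> pd.DataFrame:
        return pd.DataFrame(columns=['text', 'url', 'y_true', 'y_pred', 'y_pred_proba',
                                     'y_true_proba', 'used_for_training'])


class _FakeMetricsDBRead:
    """Stand-in for MetricsDBRead (local testing only)."""

    def read_metrics_from_db(self) -> pd.DataFrame:
        return pd.DataFrame(columns=['evaluation_date', 'mean_balanced_accuracy_score',
                                     'std_balanced_accuracy_score', 'evaluation_window_days',
                                     'n_splits', 'sample_start_date', 'sample_end_date',
                                     'sample_size_of_each_split'])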

def classification_model_monitor():
    try:
        # NOTE: the HTML wrapped by these two st.write() calls was lost; a plain <h1>
        # is assumed. The first call is the original commented-out heading.
        # st.write('<h1>Classification Model Monitor (out of service)</h1>',
        #          unsafe_allow_html=True)
        st.write('<h1>Classification Model Monitor</h1>', unsafe_allow_html=True)
        prediction_db = PredictionDBRead()
        metrics_db = MetricsDBRead()
        # Read prediction data from the DB
        prediction_data = prediction_db.read_predictions_from_db()
        # Filter prediction data
        filtered_prediction_data = filter_prediction_data(prediction_data)
        # Get adjusted prediction data
        adjusted_filtered_prediction_data = get_adjusted_predictions(filtered_prediction_data)
        # Display KPIs
        display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data)
        # Read daily metrics from the DB
        metrics_df = metrics_db.read_metrics_from_db()
        # Display the daily metrics line plot
        plot_daily_metrics(metrics_df)
        # Get the misclassified class distribution and misclassified examples
        misclassified_classes, misclassified_examples = get_misclassified_classes(filtered_prediction_data)
        # Display misclassification data
        display_misclassified_examples(misclassified_classes, misclassified_examples)
        # NOTE: the payload of this st.markdown() call (likely styling) was not recovered
        st.markdown(""" """, unsafe_allow_html=True)
    except Exception as e:
        logging.critical(f"Error in classification_model_monitor(): {e}")
        st.error("Unexpected Error. Couldn't display Classification Model Monitor")
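
# Minimal entry point, assuming this module is itself the Streamlit page script and
# is launched with `streamlit run`; under `streamlit run` the module executes with
# __name__ == "__main__", so the monitor renders on load.
if __name__ == "__main__":
    classification_model_monitor()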