##################################### # Packages & Dependencies ##################################### import param import panel as pn import torch import numpy as np import plotly.graph_objects as go from . import canvas from app_utils import styles import sys, os APP_PATH = os.path.dirname(os.path.dirname(__file__)) # Path to the digit-classifier-app directory sys.path.append(APP_PATH + '/model_training') # Imports from model_training import data_setup, model ##################################### # Plotly Panels ##################################### PLOTLY_CONFIGS = { 'displayModeBar': True, 'displaylogo': False, 'modeBarButtonsToRemove': ['autoScale', 'lasso', 'select', 'toImage', 'pan', 'zoom', 'zoomIn', 'zoomOut'] } class PlotPanels(param.Parameterized): ''' Contains all Plotly pane objects for the application. This includes the probability bar chart and the MNIST preprocessed image heat map. Args: canvas_info (param.ClassSelector): A Canvas class object to get the data URI of the drawn image. mod_path (str): The absolute path to the saved TinyVGG model. mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model. This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes ''' canvas_info = param.ClassSelector(class_ = canvas.Canvas) # Canvas object to get the data URI def __init__(self, mod_path: str, mod_kwargs: dict, **params): super().__init__(**params) self.class_labels = np.arange(0, 10) self.cnn_mod = model.TinyVGG(**mod_kwargs) self.cnn_mod.load_state_dict(torch.load(mod_path, map_location = 'cpu')) self.img_pane = pn.pane.Plotly( name = 'image_plot', config = PLOTLY_CONFIGS, sizing_mode = 'stretch_both', margin = 0, ) self.prob_pane = pn.pane.Plotly( name = 'prob_plot', config = PLOTLY_CONFIGS, sizing_mode = 'stretch_both', margin = 0 ) self.pred_txt = pn.pane.HTML( styles = {'margin':'0rem', 'color':styles.CLRS['pred_txt'], 'font-size':styles.FONTSIZES['pred_txt'], 'font-family':styles.FONTFAMILY} ) # Initialize plotly figures self._update_prediction() # Set up watchers thta update based on data URI changes self.canvas_info.param.watch(self._update_prediction, 'uri') def _update_prediction(self, *event): ''' Performs all prediction-related updates for the application. This function is connected to the URI parameter of canvas_info through a watcher. Any times the URI changes, a class prediction is immediately. Following this, the probability bar chart and model input heatmap are updated as well. ''' self._update_preprocessed_tensor() self._update_pred_txt() self._update_img_plot() self._update_prob_plot() def _update_preprocessed_tensor(self): ''' Transforms the data URI (string) from canvas_info into a preprocessed tensor. This is done by having it undergo the MNISt preprocessing pipeline (see mnist_preprocess in data_setup for details). Additionally, a prediction is made for the preprocessed tensor to get its class label. The correpsonding set of prediction probabilities are stored. ''' # Check if uri is non-empty if self.canvas_info.uri: self.input_img = data_setup.mnist_preprocess(self.canvas_info.uri) self.cnn_mod.eval() # Set CNN to eval & inference mode with torch.inference_mode(): pred_logits = self.cnn_mod(self.input_img.unsqueeze(0)) self.pred_probs = torch.softmax(pred_logits, dim = 1)[0].numpy() self.pred_label = np.argmax(self.pred_probs) else: self.input_img = torch.zeros((28, 28)) self.pred_probs = np.zeros(10) self.pred_label = None def _update_pred_txt(self): ''' Updates the prediction and probability HTML text to reflect the current data URI. ''' if self.canvas_info.uri: pred, prob = self.pred_label, f'{self.pred_probs[self.pred_label]:.3f}' else: pred, prob = 'N/A', 'N/A' self.pred_txt.object = f'''