# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This script can be used to visualize the errors made by a (duplex) TN system. More specifically, after running the evaluation script `duplex_text_normalization_test.py`, a log file containing info about the errors will be generated. The location of this file is determined by the argument `inference.errors_log_fp`. After that, we can use this script to generate a HTML visualization. USAGE Example: # python analyze_errors.py \ --errors_log_fp=PATH_TO_ERRORS_LOG_FILE_PATH \ --visualization_fp=PATH_TO_VISUALIZATION_FILE_PATH """ from argparse import ArgumentParser from typing import List from nemo.collections.nlp.data.text_normalization import constants # Longest Common Subsequence def lcs(X, Y): """ Function for finding the longest common subsequence between two lists. In this script, this function is particular used for aligning between the ground-truth output string and the predicted string (for visualization purpose). Args: X: a list Y: a list Returns: a list which is the longest common subsequence between X and Y """ m, n = len(X), len(Y) L = [[0 for x in range(n + 1)] for x in range(m + 1)] # Following steps build L[m+1][n+1] in bottom up fashion. Note # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1] for i in range(m + 1): for j in range(n + 1): if i == 0 or j == 0: L[i][j] = 0 elif X[i - 1] == Y[j - 1]: L[i][j] = L[i - 1][j - 1] + 1 else: L[i][j] = max(L[i - 1][j], L[i][j - 1]) # Following code is used to print LCS index = L[m][n] # Create a character array to store the lcs string lcs = [''] * (index + 1) lcs[index] = '' # Start from the right-most-bottom-most corner and # one by one store characters in lcs[] i = m j = n while i > 0 and j > 0: # If current character in X[] and Y are same, then # current character is part of LCS if X[i - 1] == Y[j - 1]: lcs[index - 1] = X[i - 1] i -= 1 j -= 1 index -= 1 # If not same, then find the larger of two and # go in the direction of larger value elif L[i - 1][j] > L[i][j - 1]: i -= 1 else: j -= 1 return lcs[:-1] # Classes class ErrorCase: """ This class represents an error case Args: _input: Original input string target: Ground-truth target string pred: Predicted string mode: A string indicates the mode (i.e., constants.ITN_MODE or constants.TN_MODE) """ def __init__(self, _input: str, target: str, pred: str, classes: str, mode: str): self._input = _input self.target = target self.pred = pred self.mode = mode self.classes = classes # Tokens self.target_tokens = self.target.split(' ') self.pred_tokens = self.pred.split(' ') # LCS lcs_tokens = lcs(self.target_tokens, self.pred_tokens) target_tokens_hightlight = [False] * len(self.target_tokens) pred_tokens_hightlight = [False] * len(self.pred_tokens) target_idx, pred_idx = 0, 0 for token in lcs_tokens: while self.target_tokens[target_idx] != token: target_idx += 1 while self.pred_tokens[pred_idx] != token: pred_idx += 1 target_tokens_hightlight[target_idx] = True pred_tokens_hightlight[pred_idx] = True target_idx += 1 pred_idx += 1 # Spans self.target_spans = self.get_spans(target_tokens_hightlight) self.pred_spans = self.get_spans(pred_tokens_hightlight) # Determine unhighlighted target spans unhighlighted_target_spans = [] for ix, t in enumerate(self.target_spans): if not t[-1]: unhighlighted_target_spans.append((ix, t)) # Determine unhighlighted pred spans unhighlighted_pred_spans = [] for ix, t in enumerate(self.pred_spans): if not t[-1]: unhighlighted_pred_spans.append((ix, t)) @classmethod def from_lines(cls, lines: List[str], mode: str): """ This method returns an instance of ErrorCase from raw string lines. Args: lines: A list of raw string lines for the error case. mode: A string indicates the mode (i.e., constants.ITN_MODE or constants.TN_MODE) Returns: an instance of ErrorCase. """ for line in lines: if line.startswith('Original Input'): _input = line[line.find(':') + 1 :].strip() elif line.startswith('Predicted Str'): pred = line[line.find(':') + 1 :].strip() elif line.startswith('Ground-Truth'): target = line[line.find(':') + 1 :].strip() elif line.startswith('Ground Classes'): classes = line[line.find(':') + 1 :].strip() return cls(_input, target, pred, classes, mode) def get_html(self): """ This method returns a HTML string representing this error case instance. Returns: a string contains the HTML representing this error case instance. """ html_str = '' # Input input_form = 'Written' if self.mode == constants.TN_MODE else 'Spoken' padding_multiplier = 1 if self.mode == constants.TN_MODE else 2 padding_spaces = ''.join([' '] * padding_multiplier) input_str = f'[Input ({input_form})]{padding_spaces}: {self._input}\n' html_str += input_str + ' ' # Target target_html = self.get_spans_html(self.target_spans, self.target_tokens) target_form = 'Spoken' if self.mode == constants.TN_MODE else 'Written' target_str = f'[Target ({target_form})]: {target_html}\n' html_str += target_str + ' ' # Pred pred_html = self.get_spans_html(self.pred_spans, self.pred_tokens) padding_multiplier = 10 if self.mode == constants.TN_MODE else 11 padding_spaces = ''.join([' '] * padding_multiplier) pred_str = f'[Prediction]{padding_spaces}: {pred_html}\n' html_str += pred_str + ' ' # Classes padding_multiplier = 15 if self.mode == constants.TN_MODE else 16 padding_spaces = ''.join([' '] * padding_multiplier) class_str = f'[Classes]{padding_spaces}: {self.classes}\n' html_str += class_str + ' ' # Space html_str += '\n' return html_str def get_spans(self, tokens_hightlight): """ This method extracts the list of spans. Args: tokens_hightlight: A list of boolean values where each value indicates whether a token needs to be hightlighted. Returns: spans: A list of spans. Each span is represented by a tuple of 3 elements: (1) Start Index (2) End Index (3) A boolean value indicating whether the span needs to be hightlighted. """ spans, nb_tokens = [], len(tokens_hightlight) cur_start_idx, cur_bool_val = 0, tokens_hightlight[0] for idx in range(nb_tokens): if idx == nb_tokens - 1: if tokens_hightlight[idx] != cur_bool_val: spans.append((cur_start_idx, nb_tokens - 2, cur_bool_val)) spans.append((nb_tokens - 1, nb_tokens - 1, tokens_hightlight[idx])) else: spans.append((cur_start_idx, nb_tokens - 1, cur_bool_val)) else: if tokens_hightlight[idx] != cur_bool_val: spans.append((cur_start_idx, idx - 1, cur_bool_val)) cur_start_idx, cur_bool_val = idx, tokens_hightlight[idx] return spans def get_spans_html(self, spans, tokens): """ This method generates a HTML string for a string sequence from its spans. Args: spans: A list of contiguous spans in a sequence. Each span is represented by a tuple of 3 elements: (1) Start Index (2) End Index (3) A boolean value indicating whether the span needs to be hightlighted. tokens: All tokens in the sequence Returns: html_str: A HTML string for the string sequence. """ html_str = '' for start, end, type in spans: color = 'red' if type else 'black' span_tokens = tokens[start : end + 1] span_str = '{} '.format(color, ' '.join(span_tokens)) html_str += span_str return html_str # Main function for analysis def analyze(errors_log_fp: str, visualization_fp: str): """ This method generates a HTML visualization of the error cases logged in a log file. Args: errors_log_fp: Path to the error log file visualization_fp: Path to the output visualization file """ # Read lines from errors log with open(errors_log_fp, 'r', encoding='utf-8') as f: lines = f.readlines() # Process lines tn_error_cases, itn_error_cases = [], [] for ix in range(0, len(lines), 8): mode_line = lines[ix] info_lines = lines[ix + 1 : ix + 7] # Append new error case if mode_line.startswith('Forward Problem'): mode = constants.TN_MODE tn_error_cases.append(ErrorCase.from_lines(info_lines, mode)) elif mode_line.startswith('Backward Problem'): mode = constants.ITN_MODE itn_error_cases.append(ErrorCase.from_lines(info_lines, mode)) # Basic stats print('---- Text Normalization ----') print('Number of TN errors: {}'.format(len(tn_error_cases))) print('---- Inverse Text Normalization ---- ') print('Number of ITN errors: {}'.format(len(itn_error_cases))) # Produce a visualization with open(visualization_fp, 'w+', encoding='utf-8') as f: # Appendix f.write('Appendix') f.write('Text Normalization Analysis.') f.write('Inverse Text Normalization Analysis.') # TN Section f.write('