# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to visualize the errors made by a (duplex) TN system.
More specifically, after running the evaluation script `duplex_text_normalization_test.py`,
a log file containing info about the errors will be generated. The location of this file
is determined by the argument `inference.errors_log_fp`. After that, we can use this
script to generate an HTML visualization.
USAGE Example:
# python analyze_errors.py \
--errors_log_fp=PATH_TO_ERRORS_LOG_FILE_PATH \
--visualization_fp=PATH_TO_VISUALIZATION_FILE_PATH
"""
from argparse import ArgumentParser
from typing import List
from nemo.collections.nlp.data.text_normalization import constants
# Longest Common Subsequence
def lcs(X, Y):
""" Function for finding the longest common subsequence between two lists.
In this script, this function is particular used for aligning between the
ground-truth output string and the predicted string (for visualization purpose).
Args:
X: a list
Y: a list
Returns: a list which is the longest common subsequence between X and Y
"""
m, n = len(X), len(Y)
L = [[0 for x in range(n + 1)] for x in range(m + 1)]
    # The following steps build L[m+1][n+1] in a bottom-up fashion. Note
    # that L[i][j] contains the length of the LCS of X[0..i-1] and Y[0..j-1].
for i in range(m + 1):
for j in range(n + 1):
if i == 0 or j == 0:
L[i][j] = 0
elif X[i - 1] == Y[j - 1]:
L[i][j] = L[i - 1][j - 1] + 1
else:
L[i][j] = max(L[i - 1][j], L[i][j - 1])
    # The following code reconstructs the actual LCS from the table
index = L[m][n]
    # Create a list to store the LCS tokens
lcs = [''] * (index + 1)
lcs[index] = ''
    # Start from the bottom-right corner of the table and
    # store matching tokens in lcs[] one by one
i = m
j = n
while i > 0 and j > 0:
        # If the current tokens of X and Y are the same, then
        # the token is part of the LCS
if X[i - 1] == Y[j - 1]:
lcs[index - 1] = X[i - 1]
i -= 1
j -= 1
index -= 1
        # If not the same, then move in the direction of the
        # larger of the two table values
elif L[i - 1][j] > L[i][j - 1]:
i -= 1
else:
j -= 1
return lcs[:-1]
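# Illustrative example (not executed): with X = ['the', 'cat', 'sat'] and
# Y = ['the', 'dog', 'sat'], lcs(X, Y) returns ['the', 'sat'].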
# Classes
class ErrorCase:
"""
    This class represents an error case.
    Args:
        _input: Original input string
        target: Ground-truth target string
        pred: Predicted string
        classes: A string with the ground-truth classes (taken from the 'Ground Classes' line of the log)
        mode: A string indicating the mode (i.e., constants.ITN_MODE or constants.TN_MODE)
"""
def __init__(self, _input: str, target: str, pred: str, classes: str, mode: str):
self._input = _input
self.target = target
self.pred = pred
self.mode = mode
self.classes = classes
# Tokens
self.target_tokens = self.target.split(' ')
self.pred_tokens = self.pred.split(' ')
# LCS
lcs_tokens = lcs(self.target_tokens, self.pred_tokens)
        # Mark all tokens as highlighted by default; tokens in the LCS (i.e., tokens that
        # match between target and prediction) are unmarked below, so only the differing
        # tokens end up highlighted (rendered in red by get_spans_html).
        target_tokens_hightlight = [True] * len(self.target_tokens)
        pred_tokens_hightlight = [True] * len(self.pred_tokens)
target_idx, pred_idx = 0, 0
for token in lcs_tokens:
while self.target_tokens[target_idx] != token:
target_idx += 1
while self.pred_tokens[pred_idx] != token:
pred_idx += 1
            target_tokens_hightlight[target_idx] = False
            pred_tokens_hightlight[pred_idx] = False
target_idx += 1
pred_idx += 1
# Spans
self.target_spans = self.get_spans(target_tokens_hightlight)
self.pred_spans = self.get_spans(pred_tokens_hightlight)
# Determine unhighlighted target spans
unhighlighted_target_spans = []
for ix, t in enumerate(self.target_spans):
if not t[-1]:
unhighlighted_target_spans.append((ix, t))
# Determine unhighlighted pred spans
unhighlighted_pred_spans = []
for ix, t in enumerate(self.pred_spans):
if not t[-1]:
unhighlighted_pred_spans.append((ix, t))
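    # Illustrative example (not executed), using hypothetical values:
    #   ErrorCase('1 kg', 'one kilogram', 'one kilograms', 'MEASURE', constants.TN_MODE)
    # aligns 'one kilogram' with 'one kilograms' and marks 'kilogram' / 'kilograms'
    # as the differing (highlighted) tokens.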
@classmethod
def from_lines(cls, lines: List[str], mode: str):
"""
This method returns an instance of ErrorCase from raw string lines.
Args:
lines: A list of raw string lines for the error case.
            mode: A string indicating the mode (i.e., constants.ITN_MODE or constants.TN_MODE)
Returns: an instance of ErrorCase.
"""
for line in lines:
if line.startswith('Original Input'):
_input = line[line.find(':') + 1 :].strip()
elif line.startswith('Predicted Str'):
pred = line[line.find(':') + 1 :].strip()
elif line.startswith('Ground-Truth'):
target = line[line.find(':') + 1 :].strip()
elif line.startswith('Ground Classes'):
classes = line[line.find(':') + 1 :].strip()
return cls(_input, target, pred, classes, mode)
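    # Note: from_lines() assumes each error case in the log contains lines starting with
    # 'Original Input', 'Predicted Str', 'Ground-Truth', and 'Ground Classes' (the prefixes
    # matched above); if any of these lines is missing, the corresponding variable is unbound.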
def get_html(self):
"""
        This method returns an HTML string representing this error case instance.
        Returns: a string containing the HTML representation of this error case instance.
"""
html_str = ''
# Input
input_form = 'Written' if self.mode == constants.TN_MODE else 'Spoken'
padding_multiplier = 1 if self.mode == constants.TN_MODE else 2
padding_spaces = ''.join([' '] * padding_multiplier)
input_str = f'<b>[Input ({input_form})]{padding_spaces}</b>: {self._input}</br>\n'
html_str += input_str + ' '
# Target
target_html = self.get_spans_html(self.target_spans, self.target_tokens)
target_form = 'Spoken' if self.mode == constants.TN_MODE else 'Written'
target_str = f'<b>[Target ({target_form})]</b>: {target_html}</br>\n'
html_str += target_str + ' '
# Pred
pred_html = self.get_spans_html(self.pred_spans, self.pred_tokens)
padding_multiplier = 10 if self.mode == constants.TN_MODE else 11
padding_spaces = ''.join([' '] * padding_multiplier)
pred_str = f'<b>[Prediction]{padding_spaces}</b>: {pred_html}</br>\n'
html_str += pred_str + ' '
# Classes
padding_multiplier = 15 if self.mode == constants.TN_MODE else 16
padding_spaces = ''.join([' '] * padding_multiplier)
class_str = f'<b>[Classes]{padding_spaces}</b>: {self.classes}</br>\n'
html_str += class_str + ' '
# Space
html_str += '</br>\n'
return html_str
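    # Note: get_html() emits one HTML line each for the input, target, prediction, and classes;
    # spans where the target and prediction differ are rendered in red (see get_spans_html),
    # while matching spans are rendered in black.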
def get_spans(self, tokens_hightlight):
"""
This method extracts the list of spans.
Args:
            tokens_hightlight: A list of boolean values where each value indicates whether the corresponding token needs to be highlighted.
        Returns:
            spans: A list of spans. Each span is represented by a tuple of 3 elements: (1) Start Index (2) End Index (3) A boolean value indicating whether the span needs to be highlighted.
"""
spans, nb_tokens = [], len(tokens_hightlight)
cur_start_idx, cur_bool_val = 0, tokens_hightlight[0]
for idx in range(nb_tokens):
if idx == nb_tokens - 1:
if tokens_hightlight[idx] != cur_bool_val:
spans.append((cur_start_idx, nb_tokens - 2, cur_bool_val))
spans.append((nb_tokens - 1, nb_tokens - 1, tokens_hightlight[idx]))
else:
spans.append((cur_start_idx, nb_tokens - 1, cur_bool_val))
else:
if tokens_hightlight[idx] != cur_bool_val:
spans.append((cur_start_idx, idx - 1, cur_bool_val))
cur_start_idx, cur_bool_val = idx, tokens_hightlight[idx]
return spans
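    # Illustrative example (not executed): get_spans([True, True, False, True])
    # returns [(0, 1, True), (2, 2, False), (3, 3, True)].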
def get_spans_html(self, spans, tokens):
"""
        This method generates an HTML string for a token sequence from its spans.
        Args:
            spans: A list of contiguous spans in a sequence. Each span is represented by a tuple of 3 elements: (1) Start Index (2) End Index (3) A boolean value indicating whether the span needs to be highlighted.
            tokens: All tokens in the sequence
        Returns:
            html_str: An HTML string for the string sequence.
"""
html_str = ''
        for start, end, highlighted in spans:
            color = 'red' if highlighted else 'black'
span_tokens = tokens[start : end + 1]
span_str = '<span style="color:{}">{}</span> '.format(color, ' '.join(span_tokens))
html_str += span_str
return html_str
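    # Illustrative example (not executed): get_spans_html([(0, 0, False), (1, 1, True)], ['two', 'euros'])
    # returns '<span style="color:black">two</span> <span style="color:red">euros</span> '.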
# Main function for analysis
def analyze(errors_log_fp: str, visualization_fp: str):
"""
    This method generates an HTML visualization of the error cases logged in a log file.
Args:
errors_log_fp: Path to the error log file
visualization_fp: Path to the output visualization file
"""
# Read lines from errors log
with open(errors_log_fp, 'r', encoding='utf-8') as f:
lines = f.readlines()
# Process lines
tn_error_cases, itn_error_cases = [], []
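    # Each error case is assumed to occupy 8 lines in the log: one mode line, six info
    # lines, and (presumably) a separator line produced by the evaluation script.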
for ix in range(0, len(lines), 8):
mode_line = lines[ix]
info_lines = lines[ix + 1 : ix + 7]
# Append new error case
if mode_line.startswith('Forward Problem'):
mode = constants.TN_MODE
tn_error_cases.append(ErrorCase.from_lines(info_lines, mode))
elif mode_line.startswith('Backward Problem'):
mode = constants.ITN_MODE
itn_error_cases.append(ErrorCase.from_lines(info_lines, mode))
# Basic stats
print('---- Text Normalization ----')
print('Number of TN errors: {}'.format(len(tn_error_cases)))
print('---- Inverse Text Normalization ---- ')
print('Number of ITN errors: {}'.format(len(itn_error_cases)))
# Produce a visualization
with open(visualization_fp, 'w+', encoding='utf-8') as f:
# Appendix
f.write('Appendix</br>')
f.write('<a href="#tn_section">Text Normalization Analysis.</a></br>')
f.write('<a href="#itn_section">Inverse Text Normalization Analysis.</a>')
# TN Section
f.write('<h2 id="tn_section">Text Normalization</h2>\n')
for errorcase in tn_error_cases:
f.write(errorcase.get_html())
# ITN Section
f.write('<h2 id="itn_section">Inverse Text Normalization</h2>\n')
for errorcase in itn_error_cases:
f.write(errorcase.get_html())
if __name__ == '__main__':
# Parse argument
parser = ArgumentParser()
parser.add_argument('--errors_log_fp', help='Path to the error log file', required=True)
parser.add_argument('--visualization_fp', help='Path to the output visualization file', required=True)
args = parser.parse_args()
analyze(args.errors_log_fp, args.visualization_fp)