maxspad commited on
Commit
0568b17
·
1 Parent(s): e99bd97

added explanations

Browse files
Files changed (2) hide show
  1. app.py +78 -6
  2. requirements.txt +4 -1
app.py CHANGED
@@ -3,7 +3,10 @@ import transformers as tf
3
  import pandas as pd
4
  from datetime import datetime
5
  from plotly import graph_objects as go
6
-
 
 
 
7
  from overview import NQDOverview
8
 
9
  import torch
@@ -94,10 +97,79 @@ with st.form('comment_form'):
94
  st.experimental_rerun()
95
 
96
  results = run_models(models_to_load, models, st.session_state['comment'])
97
- # Modify results to sum the QuAL score and to ignore Q3 if Q2 no suggestion
98
- # if results['q2i']['label'] == 1:
99
- # results['q3i']['label'] = 1 # can't have connection if no suggestion
100
- # results['qual']['label'] = results['q1']['label'] + (not results['q2i']['label']) + (not results['q3i']['label'])
101
 
102
  overview = NQDOverview(st, results)
103
- overview.draw()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import pandas as pd
4
  from datetime import datetime
5
  from plotly import graph_objects as go
6
+ from transformers_interpret import SequenceClassificationExplainer
7
+ from annotated_text import annotated_text
8
+ from palettable.scientific.sequential import Devon_10_r
9
+ from palettable.colorbrewer.diverging import RdYlGn_10, PuOr_10, BrBG_10
10
  from overview import NQDOverview
11
 
12
  import torch
 
97
  st.experimental_rerun()
98
 
99
  results = run_models(models_to_load, models, st.session_state['comment'])
100
+ #Modify results to sum the QuAL score and to ignore Q3 if Q2 no suggestion
101
+ if results['q2i']['label'] == 1:
102
+ results['q3i']['label'] = 1 # can't have connection if no suggestion
103
+ results['qual']['label'] = results['q1']['label'] + (not results['q2i']['label']) + (not results['q3i']['label'])
104
 
105
  overview = NQDOverview(st, results)
106
+ overview.draw()
107
+
108
+ def rescale(x):
109
+ return (x + 1.0) / 2.0
110
+
111
+ def get_explained_words(comment, pipe, label, cmap):
112
+ cls_explainer = SequenceClassificationExplainer(
113
+ pipe.model,
114
+ pipe.tokenizer)
115
+ word_attributions = cls_explainer(comment, class_name=label)[1:-1]
116
+
117
+ # Get rid of "##"
118
+ to_disp = [
119
+ (word, '', f'rgba{tuple([int(c*255) for c in cmap.mpl_colormap(rescale(word_score))])}')
120
+ for word, word_score in word_attributions
121
+ ]
122
+ return to_disp
123
+
124
+ qual_map = {
125
+ 0: 'minimal',
126
+ 1: 'very low',
127
+ 2: 'low',
128
+ 3: 'average',
129
+ 4: 'above average',
130
+ 5: 'excellent'
131
+ }
132
+
133
+ q1_map = {
134
+ 0: "minimal",
135
+ 1: "low",
136
+ 2: "moderate",
137
+ 3: "high"
138
+ }
139
+
140
+ q2i_map = {
141
+ 0: "did",
142
+ 1: "did not"
143
+ }
144
+
145
+ with st.expander('Expand to explore further'):
146
+ st.write(f'Your comment was rated as a QuAL score of **{results["qual"]["label"]}**, indicating **{qual_map[results["qual"]["label"]]}** quality feedback.')
147
+
148
+ st.markdown('### Level of Detail')
149
+ st.write(f"The model identified a **{q1_map[results['q1']['label']]}** level of detail in your comment.")
150
+ st.write("Below are words that pointed the model toward (green) or against (red) identifying a high level of detail:")
151
+ annotated_text(get_explained_words(st.session_state['comment'], models['q1'], 'LABEL_3', RdYlGn_10))
152
+
153
+ st.markdown('### Suggestion for Improvement')
154
+ st.write(f"The model **{q2i_map[results['q2i']['label']]}** predict that you provided a suggestion for improvement in your comment.")
155
+ st.write(f"Below are words that pointed the model toward (green) or against (red) identifying a suggestion for improvement:")
156
+ annotated_text(get_explained_words(st.session_state['comment'], models['q2i'], 'LABEL_0', RdYlGn_10))
157
+
158
+ if results['q2i']['label'] == 0:
159
+ st.markdown('### Suggestion Linking')
160
+ st.write(f"The model **{q2i_map[results['q3i']['label']]}** predict that you linked your suggestion")
161
+ st.write(f"Below are words that pointed the model toward (green) or against (red) identifying a linked suggestion:")
162
+ annotated_text(get_explained_words(st.session_state['comment'], models['q3i'], 'LABEL_0', RdYlGn_10))
163
+
164
+
165
+ # annotated_text(to_disp)
166
+ # cls_explainer = SequenceClassificationExplainer(
167
+ # models['q1'].model,
168
+ # models['q1'].tokenizer)
169
+ # word_attributions = cls_explainer(st.session_state['comment'], class_name='LABEL_3')[1:-1]
170
+ # to_disp = [
171
+ # (word, f'{word_score:.2f}', f'rgba{tuple([int(c*255) for c in Devon_10_r.mpl_colormap(word_score)])}')
172
+ # for word, word_score in word_attributions
173
+ # ]
174
+ # print(to_disp)
175
+ # annotated_text(to_disp)
requirements.txt CHANGED
@@ -3,9 +3,12 @@ torch
3
  torchvision
4
  torchaudio
5
  transformers
 
6
  plotly==5.11.0
7
  pandas
8
  spacy
9
  altair
10
  hydralit_components
11
- matplotlib
 
 
 
3
  torchvision
4
  torchaudio
5
  transformers
6
+ transformers-interpret
7
  plotly==5.11.0
8
  pandas
9
  spacy
10
  altair
11
  hydralit_components
12
+ matplotlib
13
+ st-annotated-text
14
+ palettable