maxspad committed
Commit 1dde253 · 1 Parent(s): f4cb8a7

basic layout and function

Files changed (2)
  1. app.py +124 -25
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,7 @@
 import streamlit as st
 import transformers as tf
+import plotly.graph_objects as go
+import matplotlib.cm as cm
 
 # Function to load and cache models
 @st.experimental_singleton(show_spinner=False)
@@ -7,6 +9,27 @@ def load_model(username, prefix, model_name):
     p = tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}')
     return p
 
+def get_results(model, c):
+    res = model(c)[0]
+    label = float(res['label'].split('_')[1])
+    score = res['score']
+    return {'label': label, 'score': score}
+
+def run_models(model_names, models, c):
+    results = {}
+    for mn in model_names:
+        results[mn] = get_results(models[mn], c)
+    return results
+
+
+st.title('How *great* is your feedback?')
+st.markdown(
+    """Medical education *requires* high-quality feedback, but evaluating feedback
+    is difficult and time-consuming. This tool uses NLP/ML to predict a validated
+    feedback quality metric known as the QuAL Score. *Try it for yourself!*
+    """)
+
+### Load models
 # Specify which models to load
 USERNAME = 'maxspad'
 PREFIX = 'nlp-qual'
@@ -25,29 +48,105 @@ for i, mn in enumerate(models_to_load):
     models[mn] = load_model(USERNAME, PREFIX, mn)
 lc_placeholder.empty()
 
-comment = st.text_area('Try a comment:')
-
-denoms = ['5','3']
-for mn in models_to_load:
-    st.header(mn)
-    cols = st.columns(2)
-    res = models[mn](comment)[0]
-
-    if mn == 'qual':
-        cols[0].metric('Score', f"{res['label'].split('_')[1]}/5")
-    elif mn == 'q1':
-        cols[0].metric('Score', f"{res['label'].split('_')[1]}/3")
-    elif mn == 'q2i':
-        if res['label'] == 'LABEL_0':
-            cols[0].metric('Suggestion for improvement?', 'Yes')
-        else:
-            cols[0].metric('Suggestion for improvement?', 'No')
-    elif mn == 'q3i':
-        if res['label'] == 'LABEL_0':
-            cols[0].metric('Suggestion linked?', 'Yes')
-        else:
-            cols[0].metric('Suggestion linked?', 'No')
-
-    cols[1].caption('Confidence')
-    cols[1].progress(res['score'])
+
+### Process input
+with st.form('comment_form'):
+    comment = st.text_area('Try a comment:')
+    left_col, right_col = st.columns([1,9], gap='medium')
+    submitted = left_col.form_submit_button('Submit')
+    try_example = right_col.form_submit_button('Try an example!')
+
+results = run_models(models_to_load, models, comment)
+
+tab_titles = ['Overview', 'Q1 - Level of Detail', 'Q2 - Suggestion Given', 'Q3 - Suggestion Linked']
+tabs = st.tabs(tab_titles)
+
+
+with tabs[0]:
+    with st.expander('What is the QuAL score?'):
+        st.markdown('**The best thing ever**!')
+    cmap = cm.get_cmap('RdYlGn')
+    color = cmap(results['qual']['label'] / 6.0)
+    color = f'rgba({int(color[0]*256)}, {int(color[1]*256)}, {int(color[2]*256)}, {int(color[3]*256)})'
+
+    fig = go.Figure(go.Indicator(
+        domain = {'x': [0, 1], 'y': [0, 1]},
+        value = results['qual']['label'],
+        mode = "gauge+number",
+        title = {'text': "QuAL"},
+        # delta = {'reference': 380},
+        gauge = {'axis': {'range': [None, 5]},
+                 'bgcolor': 'lightgray',
+                 # 'steps': [
+                 #     {'range': [0,1], 'color': "rgb(215,48,39)"},
+                 #     {'range': [1,2], 'color': "rgb(244,109,67)"},
+                 #     {'range': [2,3], 'color': "rgb(254,224,139)"},
+                 #     {'range': [3,4], 'color': "rgb(102,189,99)"},
+                 #     {'range': [4,5], 'color': "rgb(0,104,55)"}
+                 # ],
+                 'bar': {'color': color, 'thickness': 1.0},
+
+        }
+    ), layout=go.Layout(width=750, height=375))# layout={'paper_bgcolor': 'rgb(245,245,245)'})#,
+
+    st.plotly_chart(fig)
+
+    cols = st.columns(3)
+    cols[0].markdown('#### Level of Detail')
+    q1lab = results['q1']['label']
+    if q1lab == 0:
+        md_str = '# 😥 None'
+    elif q1lab == 1:
+        md_str = '# 😐 Low'
+    elif q1lab == 2:
+        md_str = '# 😊 Medium'
+    elif q1lab == 3:
+        md_str = '# 😁 High'
+    cols[0].markdown(md_str)
+
+    cols[1].markdown('#### Suggestion Given')
+    q2lab = results['q2i']['label']
+    if q2lab == 0:
+        md_str = '# ✅ Yes'
+    else:
+        md_str = '# ❌ No'
+    cols[1].markdown(md_str)
+    # cols[1].markdown('# ✅ Yes', unsafe_allow_html=True)
+
+    cols[2].markdown('#### Suggestion Linked')
+    q3lab = results['q3i']['label']
+    if q3lab == 0:
+        md_str = '# ✅ Yes'
+    else:
+        md_str = '# ❌ No'
+    cols[2].markdown(md_str)
+
+with tabs[1]:
+    st.write('hello')
+
+
+
+# denoms = ['5','3']
+# for mn in models_to_load:
+#     st.header(mn)
+#     cols = st.columns(2)
+#     res = models[mn](comment)[0]
+
+#     if mn == 'qual':
+#         cols[0].metric('Score', f"{res['label'].split('_')[1]}/5")
+#     elif mn == 'q1':
+#         cols[0].metric('Score', f"{res['label'].split('_')[1]}/3")
+#     elif mn == 'q2i':
+#         if res['label'] == 'LABEL_0':
+#             cols[0].metric('Suggestion for improvement?', 'Yes')
#         else:
#             cols[0].metric('Suggestion for improvement?', 'No')
#     elif mn == 'q3i':
#         if res['label'] == 'LABEL_0':
#             cols[0].metric('Suggestion linked?', 'Yes')
#         else:
#             cols[0].metric('Suggestion linked?', 'No')

#     cols[1].caption('Confidence')
#     cols[1].progress(res['score'])
 
 
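For reference, a minimal sketch (not part of this commit) of how the new get_results() helper interprets the text-classification pipeline output. It assumes the fine-tuned models emit labels of the form 'LABEL_<n>', as the LABEL_0 checks in the previous app.py imply; the model id and example comment below are illustrative only, built from the USERNAME, PREFIX, and 'qual' values in the diff.

import transformers as tf

# Hypothetical id from f'{USERNAME}/{PREFIX}-{model_name}' with model_name='qual'
pipe = tf.pipeline('text-classification', 'maxspad/nlp-qual-qual')

res = pipe('Good presentation, but commit to a plan earlier next time.')[0]
# res is shaped like {'label': 'LABEL_3', 'score': 0.87}

label = float(res['label'].split('_')[1])  # numeric class, e.g. 3.0
score = res['score']                       # model confidence in [0, 1]
print({'label': label, 'score': score})    # same dict shape get_results() returns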
requirements.txt CHANGED
@@ -2,3 +2,4 @@ torch
 torchvision
 torchaudio
 transformers
+plotly==5.11.0
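As a quick sanity check (a sketch, not part of this commit), the newly pinned plotly release can be imported and can draw the same kind of gauge indicator that app.py renders through st.plotly_chart:

import plotly
import plotly.graph_objects as go

# Fails loudly if a different plotly version is installed than the pin above
assert plotly.__version__ == '5.11.0', plotly.__version__

fig = go.Figure(go.Indicator(mode='gauge+number', value=3,
                             gauge={'axis': {'range': [None, 5]}}))
fig.show()  # opens the gauge in a browser tab, outside of Streamlit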