# NOTE(review): removed web-scrape residue ("Spaces:" / "Runtime error"
# banner text from the hosting page) — it was not part of the program.
import matplotlib.cm as cm
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
import transformers as tf
# Function to load and cache models
@st.cache_resource
def load_model(username, prefix, model_name):
    """Return a cached Hugging Face text-classification pipeline.

    The original header comment promised caching but none was implemented,
    so every Streamlit rerun rebuilt (and potentially re-downloaded) the
    model. ``st.cache_resource`` keeps one pipeline per unique argument
    tuple for the lifetime of the server process.

    Parameters
    ----------
    username : str
        Hub account hosting the model (e.g. ``'maxspad'``).
    prefix : str
        Shared model-name prefix (e.g. ``'nlp-qual'``).
    model_name : str
        Model suffix such as ``'qual'``, ``'q1'``, ``'q2i'``, ``'q3i'``.
    """
    return tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}')
def load_pickle(f):
    """Deserialize and return a pandas object from the pickle file at *f*."""
    data = pd.read_pickle(f)
    return data
def get_results(model, c):
    """Classify comment *c* with *model* and unpack the top prediction.

    The pipeline emits labels shaped like ``'LABEL_<n>'``; the numeric
    suffix is returned as a float together with the confidence score.
    """
    top = model(c)[0]
    _, _, raw_label = top['label'].partition('_')
    return {'label': float(raw_label), 'score': top['score']}
def run_models(model_names, models, c):
    """Score comment *c* with each named model.

    Returns a dict mapping each name in *model_names* to the
    ``{'label': ..., 'score': ...}`` result from :func:`get_results`.
    """
    return {name: get_results(models[name], c) for name in model_names}
# Page header and one-paragraph explanation of the tool.
st.title('How *great* is your feedback?')
st.markdown(
"""Medical education *requires* high-quality feedback, but evaluating feedback
is difficult and time-consuming. This tool uses NLP/ML to predict a validated
feedback quality metric known as the QuAL Score. *Try it for yourself!*
""")
### Load models
# Specify which models to load
USERNAME = 'maxspad'
PREFIX = 'nlp-qual'
models_to_load = ['qual', 'q1', 'q2i', 'q3i']
n_models = float(len(models_to_load))  # float so the progress ratio is 0.0-1.0
models = {}

# Show a progress bar while models are downloading,
# then hide it when done. The placeholder lets the whole
# loader UI be removed with a single .empty() call.
lc_placeholder = st.empty()
loader_container = lc_placeholder.container()
loader_container.caption('Loading models... please wait...')
pbar = loader_container.progress(0.0)
for i, mn in enumerate(models_to_load):
    pbar.progress((i + 1.0) / n_models)
    models[mn] = load_model(USERNAME, PREFIX, mn)
lc_placeholder.empty()
### Load example data
examples = load_pickle('test.pkl')

### Process input
# Pre-draw one random example comment for the "Try an example!" button.
# Assumes 'test.pkl' deserializes to a DataFrame with a 'comment' column.
ex = examples['comment'].sample(1).tolist()[0]
if 'comment' not in st.session_state:
    st.session_state['comment'] = ex

with st.form('comment_form'):
    comment = st.text_area('Try a comment:', value=st.session_state['comment'])
    left_col, right_col = st.columns([1, 9], gap='medium')
    submitted = left_col.form_submit_button('Submit')
    trying_example = right_col.form_submit_button('Try an example!')

# Handle whichever button was pressed, then rerun so the text area and
# results reflect the updated session state.
# NOTE(review): source indentation was lost in extraction; this handling
# behaves identically whether placed inside or after the form context.
if submitted:
    st.session_state['button_clicked'] = 'submit'
    st.session_state['comment'] = comment
    st.experimental_rerun()
elif trying_example:
    st.session_state['button_clicked'] = 'example'
    st.session_state['comment'] = ex
    st.experimental_rerun()
# Score the current comment with every loaded model.
results = run_models(models_to_load, models, st.session_state['comment'])

tab_titles = ['Overview', 'Q1 - Level of Detail', 'Q2 - Suggestion Given', 'Q3 - Suggestion Linked']
tabs = st.tabs(tab_titles)

with tabs[0]:
    with st.expander('About the QuAL score?'):
        st.markdown('**The best thing ever**!')

    # Map the predicted QuAL score onto a red->green colormap.
    cmap = cm.get_cmap('RdYlGn')
    color = cmap(results['qual']['label'] / 6.0)
    # Matplotlib colormaps return RGBA channels in [0, 1]. CSS rgba() wants
    # r/g/b as 0-255 ints and alpha as a 0-1 float. The original scaled by
    # 256 (overflowing to 256 at channel == 1.0) and int-scaled the alpha
    # too, producing an invalid value like 256 — both fixed here.
    color = f'rgba({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)}, {color[3]})'

    # Gauge visualization of the overall QuAL score (axis 0-5).
    fig = go.Figure(go.Indicator(
        domain={'x': [0, 1], 'y': [0, 1]},
        value=results['qual']['label'],
        mode="gauge+number",
        title={'text': "QuAL"},
        gauge={'axis': {'range': [None, 5]},
               'bgcolor': 'lightgray',
               'bar': {'color': color, 'thickness': 1.0},
               }
    ), layout=go.Layout(width=750, height=375))
    st.plotly_chart(fig)

    cols = st.columns(3)

    # NOTE(review): the emoji below were mojibake in the extracted source;
    # they have been reconstructed from context — confirm against the
    # deployed app.
    # Q1: level of detail, predicted on a 0-3 ordinal scale.
    cols[0].markdown('#### Level of Detail')
    q1lab = results['q1']['label']
    if q1lab == 0:
        md_str = '# 🥱 None'
    elif q1lab == 1:
        md_str = '# 😐 Low'
    elif q1lab == 2:
        md_str = '# 😊 Medium'
    else:
        # q1lab == 3; an else (rather than elif) guarantees md_str is
        # always bound even for an unexpected label value.
        md_str = '# 😁 High'
    cols[0].markdown(md_str)

    # Q2: suggestion given? Label 0 means "yes" for this inverted model.
    cols[1].markdown('#### Suggestion Given')
    q2lab = results['q2i']['label']
    md_str = '# ✅ Yes' if q2lab == 0 else '# ❌ No'
    cols[1].markdown(md_str)

    # Q3: suggestion linked to the observed behavior? Label 0 means "yes".
    cols[2].markdown('#### Suggestion Linked')
    q3lab = results['q3i']['label']
    md_str = '# ✅ Yes' if q3lab == 0 else '# ❌ No'
    cols[2].markdown(md_str)
with tabs[1]:
    # Placeholder — the per-question detail view is not implemented yet.
    # (A large block of commented-out prototype code that rendered per-model
    # metrics lived here; removed as dead code.)
    st.write('hello')