maxspad commited on
Commit
f4cb8a7
·
1 Parent(s): 4459c83

starting refactor

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -1,18 +1,21 @@
1
  import streamlit as st
2
  import transformers as tf
3
 
 
4
  @st.experimental_singleton(show_spinner=False)
5
  def load_model(username, prefix, model_name):
6
  p = tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}')
7
  return p
8
 
 
9
  USERNAME = 'maxspad'
10
  PREFIX = 'nlp-qual'
11
  models_to_load = ['qual', 'q1', 'q2i', 'q3i']
12
  n_models = float(len(models_to_load))
13
-
14
  models = {}
15
 
 
 
16
  lc_placeholder = st.empty()
17
  loader_container = lc_placeholder.container()
18
  loader_container.caption('Loading models... please wait...')
@@ -22,13 +25,13 @@ for i, mn in enumerate(models_to_load):
22
  models[mn] = load_model(USERNAME, PREFIX, mn)
23
  lc_placeholder.empty()
24
 
25
- text = st.text_area('Type your stuff')
26
 
27
  denoms = ['5','3']
28
  for mn in models_to_load:
29
  st.header(mn)
30
  cols = st.columns(2)
31
- res = models[mn](text)[0]
32
 
33
  if mn == 'qual':
34
  cols[0].metric('Score', f"{res['label'].split('_')[1]}/5")
 
1
  import streamlit as st
2
  import transformers as tf
3
 
4
+ # Function to load and cache models
5
  @st.experimental_singleton(show_spinner=False)
6
  def load_model(username, prefix, model_name):
7
  p = tf.pipeline('text-classification', f'{username}/{prefix}-{model_name}')
8
  return p
9
 
10
+ # Specify which models to load
11
  USERNAME = 'maxspad'
12
  PREFIX = 'nlp-qual'
13
  models_to_load = ['qual', 'q1', 'q2i', 'q3i']
14
  n_models = float(len(models_to_load))
 
15
  models = {}
16
 
17
+ # Show a progress bar while models are downloading,
18
+ # then hide it when done
19
  lc_placeholder = st.empty()
20
  loader_container = lc_placeholder.container()
21
  loader_container.caption('Loading models... please wait...')
 
25
  models[mn] = load_model(USERNAME, PREFIX, mn)
26
  lc_placeholder.empty()
27
 
28
+ comment = st.text_area('Try a comment:')
29
 
30
  denoms = ['5','3']
31
  for mn in models_to_load:
32
  st.header(mn)
33
  cols = st.columns(2)
34
+ res = models[mn](comment)[0]
35
 
36
  if mn == 'qual':
37
  cols[0].metric('Score', f"{res['label'].split('_')[1]}/5")