Rob Caamano commited on
Commit
682174e
·
unverified ·
1 Parent(s): 04fd2b6
Files changed (1) hide show
  1. app.py +39 -34
app.py CHANGED
@@ -1,56 +1,61 @@
1
  import streamlit as st
2
- import pandas as pd
3
- from transformers import AutoTokenizer, pipeline
4
  from transformers import (
5
  TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
6
  )
 
7
 
8
- st.title("Detecting Toxic Tweets")
9
 
10
  demo = """Your words are like poison. They seep into my mind and make me feel worthless."""
11
 
12
- text = st.text_area("Input Text", demo, height=250)
 
 
 
13
 
14
- model_options = {
15
- "DistilBERT Base Uncased (SST-2)": "distilbert-base-uncased-finetuned-sst-2-english",
16
- "Fine-tuned Toxicity Model": "RobCaamano/toxicity",
17
- }
18
- selected_model = st.selectbox("Select Model", options=list(model_options.keys()))
 
19
 
20
- mod_name = model_options[selected_model]
21
-
22
- tokenizer = AutoTokenizer.from_pretrained(mod_name)
23
- model = AutoModelForSequenceClassification.from_pretrained(mod_name)
24
  clf = pipeline(
25
  "sentiment-analysis", model=model, tokenizer=tokenizer, return_all_scores=True
26
  )
27
 
28
- if selected_model in ["Fine-tuned Toxicity Model"]:
29
- toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
30
- model.config.id2label = {i: toxicity_classes[i] for i in range(model.config.num_labels)}
 
 
 
 
 
 
31
 
32
- def get_most_toxic_class(predictions):
33
- return {model.config.id2label[i]: pred for i, pred in enumerate(predictions)}
34
 
35
  input = tokenizer(text, return_tensors="tf")
36
 
37
- if st.button("Submit", type="primary"):
38
  results = dict(d.values() for d in clf(text)[0])
39
- toxic_labels = get_most_toxic_class(results)
 
 
 
 
 
40
 
41
- tweet_portion = text[:50] + "..." if len(text) > 50 else text
 
42
 
43
- if len(toxic_labels) == 0:
44
- st.write("This text is not toxic.")
45
  else:
46
- max_toxic_class = max(toxic_labels, key=toxic_labels.get)
47
- max_probability = toxic_labels[max_toxic_class]
48
-
49
- df = pd.DataFrame(
50
- {
51
- "Text (portion)": [tweet_portion],
52
- "Toxicity Class": [max_toxic_class],
53
- "Probability": [max_probability],
54
- }
55
- )
56
- st.table(df)
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer
 
3
  from transformers import (
4
  TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
5
  )
6
+ from transformers import pipeline
7
 
8
+ st.title("Toxic Tweet Classifier")
9
 
10
  demo = """Your words are like poison. They seep into my mind and make me feel worthless."""
11
 
12
+ text = ""
13
+ submit = False
14
+ model_name = ""
15
+ col1, col2, col3 = st.columns([2,1,1])
16
 
17
+ with st.container():
18
+ model_name = st.selectbox(
19
+ "Select the model you want to use below.",
20
+ ("RobCaamano/toxicity",),
21
+ )
22
+ submit = st.button("Submit", type="primary", use_container_width=True)
23
 
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
26
  clf = pipeline(
27
  "sentiment-analysis", model=model, tokenizer=tokenizer, return_all_scores=True
28
  )
29
 
30
+ with col1:
31
+ st.subheader("Tweet")
32
+ text = st.text_area("Input text", demo, height=275)
33
+
34
+ with col2:
35
+ st.subheader("Classification")
36
+
37
+ with col3:
38
+ st.subheader("Probability")
39
 
 
 
40
 
41
  input = tokenizer(text, return_tensors="tf")
42
 
43
+ if submit:
44
  results = dict(d.values() for d in clf(text)[0])
45
+ classes = {k: results[k] for k in results.keys() if not k == "toxic"}
46
+
47
+ max_class = max(classes, key=classes.get)
48
+
49
+ with col2:
50
+ st.write(f"#### {max_class}")
51
 
52
+ with col3:
53
+ st.write(f"#### **{classes[max_class]:.2f}%**")
54
 
55
+ if results["toxic"] < 0.5:
56
+ st.success("This tweet is unlikely to be be toxic!", icon=":white_check_mark:")
57
  else:
58
+ st.warning('This tweet is likely to be toxic.', icon=":warning:")
59
+
60
+ expander = st.expander("Raw output")
61
+ expander.write(results)