Rob Caamano committed on
Commit
c518343
·
unverified ·
1 Parent(s): 2d942ee
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -4,11 +4,10 @@ from transformers import AutoTokenizer
4
  from transformers import (
5
  TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
6
  )
7
- from transformers import pipeline
8
 
9
  st.title("Detecting Toxic Tweets")
10
 
11
- demo = """Your words are like poison. They seep into my mind and make me feel worthless."""
12
 
13
  text = st.text_area("Input text", demo, height=250)
14
 
@@ -29,15 +28,15 @@ if selected_model == "Fine-tuned Toxicity Model":
29
  toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
30
  model.config.id2label = {i: toxicity_classes[i] for i in range(model.config.num_labels)}
31
 
32
- clf = pipeline(
33
- "text-classification", model=model, tokenizer=tokenizer, return_all_scores=True
34
- )
35
 
36
  input = tokenizer(text, return_tensors="tf")
 
37
 
38
  if st.button("Submit", type="primary"):
39
- results = clf(text)[0]
40
- max_class = max(results, key=lambda x: x["score"])
41
 
42
  tweet_portion = text[:50] + "..." if len(text) > 50 else text
43
 
@@ -50,8 +49,8 @@ if st.button("Submit", type="primary"):
50
  df = pd.DataFrame(
51
  {
52
  "Tweet (portion)": [tweet_portion],
53
- column_name: [max_class["label"]],
54
- "Probability": [max_class["score"]],
55
  }
56
  )
57
  st.table(df)
 
4
  from transformers import (
5
  TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
6
  )
 
7
 
8
  st.title("Detecting Toxic Tweets")
9
 
10
+ demo = """I'm so proud of myself for accomplishing my goals today. #motivation #success"""
11
 
12
  text = st.text_area("Input text", demo, height=250)
13
 
 
28
  toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
29
  model.config.id2label = {i: toxicity_classes[i] for i in range(model.config.num_labels)}
30
 
31
+ def get_highest_toxicity_class(prediction):
32
+ max_index = prediction.argmax()
33
+ return model.config.id2label[max_index], prediction[max_index]
34
 
35
  input = tokenizer(text, return_tensors="tf")
36
+ prediction = model(input)[0].numpy()[0]
37
 
38
  if st.button("Submit", type="primary"):
39
+ label, probability = get_highest_toxicity_class(prediction)
 
40
 
41
  tweet_portion = text[:50] + "..." if len(text) > 50 else text
42
 
 
49
  df = pd.DataFrame(
50
  {
51
  "Tweet (portion)": [tweet_portion],
52
+ column_name: [label],
53
+ "Probability": [probability],
54
  }
55
  )
56
  st.table(df)