Rob Caamano committed on
Commit 7b977eb · unverified · 1 Parent(s): b081db9

Update app.py

Files changed (1)
  1. app.py +14 -21
app.py CHANGED
@@ -26,32 +26,25 @@ if selected_model in ["Fine-tuned Toxicity Model"]:
     toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
     model.config.id2label = {i: toxicity_classes[i] for i in range(model.config.num_labels)}
 
-    def get_toxicity_class(prediction):
-        max_index = prediction.argmax()
-        return model.config.id2label[max_index], prediction[max_index]
+    def get_toxicity_class(predictions, threshold=0.5):
+        return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
 
 input = tokenizer(text, return_tensors="tf")
 prediction = model(input)[0].numpy()[0]
 
 if st.button("Submit", type="primary"):
-    label, probability = get_toxicity_class(prediction)
-
-    tweet_portion = text[:50] + "..." if len(text) > 50 else text
+    toxic_labels = get_toxicity_class(prediction)
 
-    if selected_model in ["Fine-tuned Toxicity Model"]:
-        column_name = "Toxicity Class"
-    else:
-        column_name = "Prediction"
+    tweet_portion = text[:50] + "..." if len(text) > 50 else text
 
-    if probability < 0.1:
+    if len(toxic_labels) == 0:
         st.write("This text is not toxic.")
-
-    df = pd.DataFrame(
-        {
-            "Text (portion)": [tweet_portion],
-            column_name: [label],
-            "Probability": [probability],
-        }
-    )
-
-    st.table(df)
+    else:
+        df = pd.DataFrame(
+            {
+                "Text (portion)": [tweet_portion] * len(toxic_labels),
+                "Toxicity Class": list(toxic_labels.keys()),
+                "Probability": list(toxic_labels.values()),
+            }
+        )
+        st.table(df)
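
For context on what the commit changes: the removed get_toxicity_class returned only the single highest-scoring class and its score (argmax), while the new version is multi-label and returns every class whose score meets a threshold (0.5 by default), so one table row is rendered per detected class. Below is a minimal standalone sketch of that behaviour; the id2label mapping, the made-up scores, and the placeholder tweet text are illustrative stand-ins for the real model output, not part of the commit.

import pandas as pd

# Hypothetical stand-ins for model.config.id2label and the model's per-class scores.
id2label = {0: "toxic", 1: "severe_toxic", 2: "obscene", 3: "threat", 4: "insult", 5: "identity_hate"}
prediction = [0.91, 0.12, 0.67, 0.03, 0.55, 0.08]

def get_toxicity_class(predictions, threshold=0.5):
    # Multi-label: keep every class whose score clears the threshold,
    # rather than only the argmax class as in the removed version.
    return {id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}

toxic_labels = get_toxicity_class(prediction)
print(toxic_labels)  # {'toxic': 0.91, 'obscene': 0.67, 'insult': 0.55}

# Same shape as the table the new else branch passes to st.table:
df = pd.DataFrame(
    {
        "Text (portion)": ["example tweet..."] * len(toxic_labels),
        "Toxicity Class": list(toxic_labels.keys()),
        "Probability": list(toxic_labels.values()),
    }
)
print(df)

Repeating the text portion once per detected class is what lets a single submission produce several rows, one for each toxicity class above the threshold.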