Rob Caamano commited on
Commit
f7d1441
·
unverified ·
1 Parent(s): d4e5db4
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -28,30 +28,29 @@ if selected_model == "Fine-tuned Toxicity Model":
28
 
29
  def get_highest_toxicity_class(prediction):
30
  max_index = prediction.argmax()
31
- max_value = prediction[max_index]
32
- return [(model.config.id2label[i], pred_value) for i, pred_value in enumerate(prediction) if pred_value >= 0.5]
33
 
34
  input = tokenizer(text, return_tensors="tf")
35
  prediction = model(input)[0].numpy()[0]
36
 
37
  if st.button("Submit", type="primary"):
38
- labels_with_probabilities = get_highest_toxicity_class(prediction)
39
 
40
  tweet_portion = text[:50] + "..." if len(text) > 50 else text
41
 
42
  if selected_model == "Fine-tuned Toxicity Model":
43
- column_name = "Toxicity Classes"
44
  else:
45
  column_name = "Prediction"
46
 
47
- if not labels_with_probabilities:
48
  st.write("This tweet is not toxic.")
49
- else:
50
- df = pd.DataFrame(
51
- {
52
- "Tweet (portion)": [tweet_portion] * len(labels_with_probabilities),
53
- column_name: [label for label, _ in labels_with_probabilities],
54
- "Probability": [probability for _, probability in labels_with_probabilities],
55
- }
56
- )
57
- st.table(df)
 
28
 
29
  def get_highest_toxicity_class(prediction):
30
  max_index = prediction.argmax()
31
+ return model.config.id2label[max_index], prediction[max_index]
 
32
 
33
  input = tokenizer(text, return_tensors="tf")
34
  prediction = model(input)[0].numpy()[0]
35
 
36
  if st.button("Submit", type="primary"):
37
+ label, probability = get_highest_toxicity_class(prediction)
38
 
39
  tweet_portion = text[:50] + "..." if len(text) > 50 else text
40
 
41
  if selected_model == "Fine-tuned Toxicity Model":
42
+ column_name = "Highest Toxicity Class"
43
  else:
44
  column_name = "Prediction"
45
 
46
+ if probability < 0.1:
47
  st.write("This tweet is not toxic.")
48
+
49
+ df = pd.DataFrame(
50
+ {
51
+ "Tweet (portion)": [tweet_portion],
52
+ column_name: [label],
53
+ "Probability": [probability],
54
+ }
55
+ )
56
+ st.table(df)