Rob Caamano commited on
Commit
d4e5db4
·
unverified ·
1 Parent(s): 84bfc73
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -28,29 +28,30 @@ if selected_model == "Fine-tuned Toxicity Model":
28
 
29
  def get_highest_toxicity_class(prediction):
30
  max_index = prediction.argmax()
31
- return model.config.id2label[max_index], prediction[max_index]
 
32
 
33
  input = tokenizer(text, return_tensors="tf")
34
  prediction = model(input)[0].numpy()[0]
35
 
36
  if st.button("Submit", type="primary"):
37
- label, probability = get_highest_toxicity_class(prediction)
38
 
39
  tweet_portion = text[:50] + "..." if len(text) > 50 else text
40
 
41
  if selected_model == "Fine-tuned Toxicity Model":
42
- column_name = "Highest Toxicity Class"
43
  else:
44
  column_name = "Prediction"
45
 
46
- if probability < 0.1:
47
  st.write("This tweet is not toxic.")
48
-
49
- df = pd.DataFrame(
50
- {
51
- "Tweet (portion)": [tweet_portion],
52
- column_name: [label],
53
- "Probability": [probability],
54
- }
55
- )
56
- st.table(df)
 
28
 
29
  def get_highest_toxicity_class(prediction):
30
  max_index = prediction.argmax()
31
+ max_value = prediction[max_index]
32
+ return [(model.config.id2label[i], pred_value) for i, pred_value in enumerate(prediction) if pred_value >= 0.5]
33
 
34
  input = tokenizer(text, return_tensors="tf")
35
  prediction = model(input)[0].numpy()[0]
36
 
37
  if st.button("Submit", type="primary"):
38
+ labels_with_probabilities = get_highest_toxicity_class(prediction)
39
 
40
  tweet_portion = text[:50] + "..." if len(text) > 50 else text
41
 
42
  if selected_model == "Fine-tuned Toxicity Model":
43
+ column_name = "Toxicity Classes"
44
  else:
45
  column_name = "Prediction"
46
 
47
+ if not labels_with_probabilities:
48
  st.write("This tweet is not toxic.")
49
+ else:
50
+ df = pd.DataFrame(
51
+ {
52
+ "Tweet (portion)": [tweet_portion] * len(labels_with_probabilities),
53
+ column_name: [label for label, _ in labels_with_probabilities],
54
+ "Probability": [probability for _, probability in labels_with_probabilities],
55
+ }
56
+ )
57
+ st.table(df)