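"""Streamlit demo: classify a tweet's toxicity with selectable DistilBERT checkpoints."""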
import streamlit as st
import pandas as pd
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

st.title("Detecting Toxic Tweets")

demo = """I'm so proud of myself for accomplishing my goals today. #motivation #success"""

text = st.text_area("Input text", demo, height=250)

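# Two selectable checkpoints: a stock DistilBERT sentiment model (SST-2) for
# comparison, and a DistilBERT fine-tuned for multi-label toxicity detection.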
model_options = {
    "DistilBERT Base Uncased (SST-2)": "distilbert-base-uncased-finetuned-sst-2-english",
    "Fine-tuned Toxicity Model": "RobCaamano/toxicity_distilbert",
}
selected_model = st.selectbox("Select Model", options=list(model_options.keys()))

mod_name = model_options[selected_model]

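# Load the tokenizer and TensorFlow model for the selected checkpoint. Note
# that Streamlit reruns the whole script on every interaction; wrapping these
# loads in an @st.cache_resource function would avoid reloading each time.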
tokenizer = AutoTokenizer.from_pretrained(mod_name)
model = TFAutoModelForSequenceClassification.from_pretrained(mod_name)

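# Replace the checkpoint's generic label names with the six toxicity classes
# it was fine-tuned on.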
if selected_model == "Fine-tuned Toxicity Model":
    toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
    model.config.id2label = {i: toxicity_classes[i] for i in range(model.config.num_labels)}

def get_highest_toxicity_class(prediction):
    """Return the top-scoring label and its probability for one prediction."""
    max_index = int(prediction.argmax())
    return model.config.id2label[max_index], prediction[max_index]

if st.button("Submit", type="primary"):
    # Tokenize on demand, truncating to the model's maximum sequence length.
    inputs = tokenizer(text, return_tensors="tf", truncation=True)
    logits = model(inputs).logits[0]

    # The model outputs raw logits, not probabilities. The toxicity model is
    # multi-label, so squash each class independently with a sigmoid; the
    # SST-2 model is single-label, so normalize across classes with a softmax.
    if selected_model == "Fine-tuned Toxicity Model":
        prediction = tf.math.sigmoid(logits).numpy()
    else:
        prediction = tf.nn.softmax(logits).numpy()

    label, probability = get_highest_toxicity_class(prediction)
    
    tweet_portion = text[:50] + "..." if len(text) > 50 else text

    if selected_model == "Fine-tuned Toxicity Model":
        column_name = "Highest Toxicity Class"
    else:
        column_name = "Prediction"
    
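    # Treat a top-class probability below 10% as non-toxic.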
    if probability < 0.1:
        st.write("This tweet is not toxic.")
    else:
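        # Display the truncated tweet, predicted class, and its probability.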
        df = pd.DataFrame(
            {
                "Tweet (portion)": [tweet_portion],
                column_name: [label],
                "Probability": [probability],
            }
        )
        st.table(df)