# classifier.py
import torch
from model_loader import classifier_model
from paraphraser import paraphrase_comment
from metrics import compute_semantic_similarity, compute_empathy_score, compute_bleu_score, compute_rouge_score
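# Note: model_loader is expected to expose `classifier_model` with `.model` and
# `.tokenizer` attributes (the fine-tuned XLM-RoBERTa sequence classifier);
# paraphraser and metrics are project-local modules used exactly as imported above.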

def classify_toxic_comment(comment):
    """
    Classify a comment as toxic or non-toxic using the fine-tuned XLM-RoBERTa model.
    If toxic, paraphrase the comment, re-evaluate, and compute essential metrics.
    Returns a 15-item tuple: the prediction label, confidence, color, toxicity score, and bias score
    for the original comment, followed by the paraphrased comment with its own label, confidence,
    color, toxicity score, and bias score, and the semantic similarity, empathy, BLEU, and ROUGE
    metrics (all paraphrase-related fields are None when the comment is non-toxic).
    """
    if not comment.strip():
        # Keep the error tuple the same length (15 items) as the success-path return below.
        return ("Error: Please enter a comment.",) + (None,) * 14

    # Access the model and tokenizer
    model = classifier_model.model
    tokenizer = classifier_model.tokenizer

    # Tokenize the input comment
    inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get the predicted class (0 = non-toxic, 1 = toxic)
    predicted_class = torch.argmax(logits, dim=1).item()
    label = "Toxic" if predicted_class == 1 else "Non-Toxic"
    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
    label_color = "red" if label == "Toxic" else "green"

    # Compute Toxicity Score (approximated as the probability of the toxic class)
    toxicity_score = torch.softmax(logits, dim=1)[0][1].item()
    toxicity_score = round(toxicity_score, 2)

    # Simulate Bias Score (placeholder)
    bias_score = 0.01 if label == "Non-Toxic" else 0.15
    bias_score = round(bias_score, 2)

    # If the comment is toxic, paraphrase it and compute essential metrics
    paraphrased_comment = None
    paraphrased_label = None
    paraphrased_confidence = None
    paraphrased_color = None
    paraphrased_toxicity_score = None
    paraphrased_bias_score = None
    semantic_similarity = None
    empathy_score = None
    bleu_score = None
    rouge_scores = None

    if label == "Toxic":
        # Paraphrase the comment
        paraphrased_comment = paraphrase_comment(comment)

        # Re-evaluate the paraphrased comment
        paraphrased_inputs = tokenizer(paraphrased_comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            paraphrased_outputs = model(**paraphrased_inputs)
            paraphrased_logits = paraphrased_outputs.logits

        paraphrased_predicted_class = torch.argmax(paraphrased_logits, dim=1).item()
        paraphrased_label = "Toxic" if paraphrased_predicted_class == 1 else "Non-Toxic"
        paraphrased_confidence = torch.softmax(paraphrased_logits, dim=1)[0][paraphrased_predicted_class].item()
        paraphrased_color = "red" if paraphrased_label == "Toxic" else "green"
        paraphrased_toxicity_score = torch.softmax(paraphrased_logits, dim=1)[0][1].item()
        paraphrased_toxicity_score = round(paraphrased_toxicity_score, 2)
        paraphrased_bias_score = 0.01 if paraphrased_label == "Non-Toxic" else 0.15  # Placeholder
        paraphrased_bias_score = round(paraphrased_bias_score, 2)

        # Compute essential metrics
        semantic_similarity = compute_semantic_similarity(comment, paraphrased_comment)
        empathy_score = compute_empathy_score(paraphrased_comment)
        bleu_score = compute_bleu_score(comment, paraphrased_comment)
        rouge_scores = compute_rouge_score(comment, paraphrased_comment)

    return (
        f"Prediction: {label}", confidence, label_color, toxicity_score, bias_score,
        paraphrased_comment, f"Prediction: {paraphrased_label}" if paraphrased_comment else None,
        paraphrased_confidence, paraphrased_color, paraphrased_toxicity_score, paraphrased_bias_score,
        semantic_similarity, empathy_score, bleu_score, rouge_scores
    )
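
# A minimal usage sketch (assumption: the module is run directly for a quick manual check;
# the sample comment and the __main__ guard are illustrative and not part of the original project).
if __name__ == "__main__":
    sample = "You people are all useless."  # hypothetical example input
    results = classify_toxic_comment(sample)
    prediction, confidence = results[0], results[1]
    paraphrased = results[5]
    print(prediction, f"(confidence: {confidence:.2f})" if confidence is not None else "")
    if paraphrased is not None:
        print("Paraphrased:", paraphrased)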