# classifier.py
import torch
import time
from model_loader import classifier_model
from paraphraser import paraphrase_comment
from metrics import compute_semantic_similarity, compute_empathy_score

def classify_toxic_comment(comment):
    """
    Classify a comment as toxic or non-toxic using the fine-tuned XLM-RoBERTa model.
    If toxic, paraphrase the comment, re-evaluate, and compute essential metrics.
    Returns the prediction label, confidence, color, toxicity score, bias score, paraphrased comment (if applicable), and its metrics.
    """
    start_total = time.time()
    print("Starting classification...")

    if not comment.strip():
        return "Error: Please enter a comment.", None, None, None, None, None, None, None, None, None, None

    # Access the model and tokenizer
    model = classifier_model.model
    tokenizer = classifier_model.tokenizer

    # Tokenize the input comment
    start_classification = time.time()
    inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get the predicted class (0 = non-toxic, 1 = toxic)
    predicted_class = torch.argmax(logits, dim=1).item()
    label = "Toxic" if predicted_class == 1 else "Non-Toxic"
    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
    label_color = "red" if label == "Toxic" else "green"

    # Compute Toxicity Score (approximated as the probability of the toxic class)
    toxicity_score = torch.softmax(logits, dim=1)[0][1].item()
    toxicity_score = round(toxicity_score, 2)

    # Simulate Bias Score (placeholder)
    bias_score = 0.01 if label == "Non-Toxic" else 0.15
    bias_score = round(bias_score, 2)
    print(f"Classification took {time.time() - start_classification:.2f} seconds")

    # If the comment is toxic, paraphrase it and compute essential metrics
    paraphrased_comment = None
    paraphrased_label = None
    paraphrased_confidence = None
    paraphrased_color = None
    paraphrased_toxicity_score = None
    paraphrased_bias_score = None
    semantic_similarity = None
    empathy_score = None

    if label == "Toxic":
        # Paraphrase the comment
        start_paraphrase = time.time()
        paraphrased_comment = paraphrase_comment(comment)
        print(f"Paraphrasing took {time.time() - start_paraphrase:.2f} seconds")

        # Re-evaluate the paraphrased comment
        start_reclassification = time.time()
        paraphrased_inputs = tokenizer(paraphrased_comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            paraphrased_outputs = model(**paraphrased_inputs)
            paraphrased_logits = paraphrased_outputs.logits

        paraphrased_predicted_class = torch.argmax(paraphrased_logits, dim=1).item()
        paraphrased_label = "Toxic" if paraphrased_predicted_class == 1 else "Non-Toxic"
        paraphrased_confidence = torch.softmax(paraphrased_logits, dim=1)[0][paraphrased_predicted_class].item()
        paraphrased_color = "red" if paraphrased_label == "Toxic" else "green"
        paraphrased_toxicity_score = torch.softmax(paraphrased_logits, dim=1)[0][1].item()
        paraphrased_toxicity_score = round(paraphrased_toxicity_score, 2)
        paraphrased_bias_score = 0.01 if paraphrased_label == "Non-Toxic" else 0.15  # Placeholder
        paraphrased_bias_score = round(paraphrased_bias_score, 2)
        print(f"Reclassification of paraphrased comment took {time.time() - start_reclassification:.2f} seconds")

        # Compute essential metrics
        start_metrics = time.time()
        semantic_similarity = compute_semantic_similarity(comment, paraphrased_comment)
        empathy_score = compute_empathy_score(paraphrased_comment)
        print(f"Metrics computation took {time.time() - start_metrics:.2f} seconds")

    print(f"Total processing time: {time.time() - start_total:.2f} seconds")

    return (
        f"Prediction: {label}", confidence, label_color, toxicity_score, bias_score,
        paraphrased_comment, f"Prediction: {paraphrased_label}" if paraphrased_comment else None,
        paraphrased_confidence, paraphrased_color, paraphrased_toxicity_score, paraphrased_bias_score,
        semantic_similarity, empathy_score
    )
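

# Illustrative usage sketch (not part of the original module): shows how the
# 13-tuple returned by classify_toxic_comment might be unpacked. It assumes the
# model_loader, paraphraser, and metrics modules are importable and their models
# are already loaded; the sample comment below is hypothetical.
if __name__ == "__main__":
    (
        prediction, confidence, color, toxicity, bias,
        para_comment, para_prediction, para_confidence, para_color,
        para_toxicity, para_bias, similarity, empathy,
    ) = classify_toxic_comment("You are a wonderful person.")

    print(prediction, f"(confidence={confidence:.2f}, toxicity={toxicity}, bias={bias})")
    if para_comment is not None:
        # Only present when the original comment was classified as Toxic.
        print("Paraphrased:", para_comment)
        print(para_prediction, f"(similarity={similarity}, empathy={empathy})")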