# classifier.py
import torch
from model_loader import classifier_model
from paraphraser import paraphrase_comment
from metrics import compute_semantic_similarity, compute_empathy_score, compute_bleu_score, compute_rouge_score
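# Expected interfaces of the imported helpers, as used in this module (descriptive note,
# inferred from the call sites below rather than from the helpers' own docs):
#   - classifier_model exposes .model (fine-tuned XLM-RoBERTa sequence classifier) and .tokenizer
#   - paraphrase_comment(text) returns a paraphrased string
#   - the metric helpers take (original, paraphrased), except compute_empathy_score,
#     which scores only the paraphrased comment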
def classify_toxic_comment(comment):
"""
Classify a comment as toxic or non-toxic using the fine-tuned XLM-RoBERTa model.
If toxic, paraphrase the comment, re-evaluate, and compute essential metrics.
Returns the prediction label, confidence, color, toxicity score, bias score, paraphrased comment (if applicable), and its metrics.
"""
    if not comment.strip():
        # Return the same 15-element shape as the success path so downstream unpacking stays consistent
        return ("Error: Please enter a comment.",) + (None,) * 14
    # Access the model and tokenizer
    model = classifier_model.model
    tokenizer = classifier_model.tokenizer

    # Tokenize the input comment
    inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get the predicted class (0 = non-toxic, 1 = toxic)
    predicted_class = torch.argmax(logits, dim=1).item()
    label = "Toxic" if predicted_class == 1 else "Non-Toxic"
    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
    label_color = "red" if label == "Toxic" else "green"

    # Compute the toxicity score (approximated as the probability of the toxic class)
    toxicity_score = torch.softmax(logits, dim=1)[0][1].item()
    toxicity_score = round(toxicity_score, 2)

    # Simulate the bias score (placeholder)
    bias_score = 0.01 if label == "Non-Toxic" else 0.15
    bias_score = round(bias_score, 2)
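    # NOTE: the softmax -> label/confidence/toxicity/bias steps above are repeated for the
    # paraphrased comment in the branch below; if this module grows, they could be factored
    # into a small shared helper.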
    # If the comment is toxic, paraphrase it and compute essential metrics
    paraphrased_comment = None
    paraphrased_label = None
    paraphrased_confidence = None
    paraphrased_color = None
    paraphrased_toxicity_score = None
    paraphrased_bias_score = None
    semantic_similarity = None
    empathy_score = None
    bleu_score = None
    rouge_scores = None
if label == "Toxic":
# Paraphrase the comment
paraphrased_comment = paraphrase_comment(comment)
# Re-evaluate the paraphrased comment
paraphrased_inputs = tokenizer(paraphrased_comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
paraphrased_outputs = model(**paraphrased_inputs)
paraphrased_logits = paraphrased_outputs.logits
paraphrased_predicted_class = torch.argmax(paraphrased_logits, dim=1).item()
paraphrased_label = "Toxic" if paraphrased_predicted_class == 1 else "Non-Toxic"
paraphrased_confidence = torch.softmax(paraphrased_logits, dim=1)[0][paraphrased_predicted_class].item()
paraphrased_color = "red" if paraphrased_label == "Toxic" else "green"
paraphrased_toxicity_score = torch.softmax(paraphrased_logits, dim=1)[0][1].item()
paraphrased_toxicity_score = round(paraphrased_toxicity_score, 2)
paraphrased_bias_score = 0.01 if paraphrased_label == "Non-Toxic" else 0.15 # Placeholder
paraphrased_bias_score = round(paraphrased_bias_score, 2)
# Compute essential metrics
semantic_similarity = compute_semantic_similarity(comment, paraphrased_comment)
empathy_score = compute_empathy_score(paraphrased_comment)
bleu_score = compute_bleu_score(comment, paraphrased_comment)
rouge_scores = compute_rouge_score(comment, paraphrased_comment)
    return (
        f"Prediction: {label}", confidence, label_color, toxicity_score, bias_score,
        paraphrased_comment, f"Prediction: {paraphrased_label}" if paraphrased_comment else None,
        paraphrased_confidence, paraphrased_color, paraphrased_toxicity_score, paraphrased_bias_score,
        semantic_similarity, empathy_score, bleu_score, rouge_scores
    )
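

# Usage sketch (not part of the Space's UI wiring): a minimal, hedged example of calling
# classify_toxic_comment directly, e.g. as a quick smoke test. The sample comment is
# hypothetical, and running this assumes the sibling modules (model_loader, paraphraser,
# metrics) and their model weights load successfully.
if __name__ == "__main__":
    result = classify_toxic_comment("You are a terrible person.")  # hypothetical example input
    (prediction, confidence, color, toxicity, bias,
     paraphrase, paraphrase_prediction, paraphrase_confidence, paraphrase_color,
     paraphrase_toxicity, paraphrase_bias, similarity, empathy, bleu, rouge) = result
    print(prediction, f"(confidence={confidence}, toxicity={toxicity}, bias={bias})")
    if paraphrase is not None:
        print("Paraphrase:", paraphrase)
        print(paraphrase_prediction, f"(toxicity={paraphrase_toxicity})")
        print("Similarity:", similarity, "Empathy:", empathy, "BLEU:", bleu, "ROUGE:", rouge)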