JanviMl committed on
Commit bf242e6 · verified · 1 Parent(s): c586725

Update classifier.py

Files changed (1)
  1. classifier.py +63 -92
classifier.py CHANGED
@@ -1,98 +1,69 @@
- # classifier.py
- import torch
- import time
- from model_loader import classifier_model
- from paraphraser import paraphrase_comment
- from metrics import compute_semantic_similarity, compute_empathy_score, compute_bleu_score, compute_rouge_score

- def classify_toxic_comment(comment):
      """
-     Classify a comment as toxic or non-toxic using the fine-tuned XLM-RoBERTa model.
-     If toxic, paraphrase the comment, re-evaluate, and compute essential metrics.
-     Returns the prediction label, confidence, color, toxicity score, bias score, paraphrased comment (if applicable), and its metrics.
      """
-     start_total = time.time()
-     print("Starting classification...")
-
-     if not comment.strip():
-         return "Error: Please enter a comment.", None, None, None, None, None, None, None, None, None, None, None, None
-
-     # Access the model and tokenizer
-     model = classifier_model.model
-     tokenizer = classifier_model.tokenizer
-
-     # Tokenize the input comment
-     start_classification = time.time()
-     inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
-
-     # Run inference
-     with torch.no_grad():
-         outputs = model(**inputs)
-         logits = outputs.logits
-
-     # Get the predicted class (0 = non-toxic, 1 = toxic)
-     predicted_class = torch.argmax(logits, dim=1).item()
-     label = "Toxic" if predicted_class == 1 else "Non-Toxic"
-     confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
-     label_color = "red" if label == "Toxic" else "green"
-
-     # Compute Toxicity Score (approximated as the probability of the toxic class)
-     toxicity_score = torch.softmax(logits, dim=1)[0][1].item()
-     toxicity_score = round(toxicity_score, 2)
-
-     # Simulate Bias Score (placeholder)
-     bias_score = 0.01 if label == "Non-Toxic" else 0.15
-     bias_score = round(bias_score, 2)
-     print(f"Classification took {time.time() - start_classification:.2f} seconds")
-
-     # If the comment is toxic, paraphrase it and compute essential metrics
-     paraphrased_comment = None
-     paraphrased_prediction = None
-     paraphrased_confidence = None
-     paraphrased_color = None
-     paraphrased_toxicity_score = None
-     paraphrased_bias_score = None
-     semantic_similarity = None
-     empathy_score = None
-     bleu_score = None
-     rouge_scores = None

-     if label == "Toxic":
-         # Paraphrase the comment
-         start_paraphrase = time.time()
-         paraphrased_comment = paraphrase_comment(comment)
-         print(f"Paraphrasing took {time.time() - start_paraphrase:.2f} seconds")
-
-         # Re-evaluate the paraphrased comment
-         start_reclassification = time.time()
-         paraphrased_inputs = tokenizer(paraphrased_comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
-         with torch.no_grad():
-             paraphrased_outputs = model(**paraphrased_inputs)
-             paraphrased_logits = paraphrased_outputs.logits
-
-         paraphrased_predicted_class = torch.argmax(paraphrased_logits, dim=1).item()
-         paraphrased_label = "Toxic" if paraphrased_predicted_class == 1 else "Non-Toxic"
-         paraphrased_confidence = torch.softmax(paraphrased_logits, dim=1)[0][paraphrased_predicted_class].item()
-         paraphrased_color = "red" if paraphrased_label == "Toxic" else "green"
-         paraphrased_toxicity_score = torch.softmax(paraphrased_logits, dim=1)[0][1].item()
-         paraphrased_toxicity_score = round(paraphrased_toxicity_score, 2)
-         paraphrased_bias_score = 0.01 if paraphrased_label == "Non-Toxic" else 0.15  # Placeholder
-         paraphrased_bias_score = round(paraphrased_bias_score, 2)
-         print(f"Reclassification of paraphrased comment took {time.time() - start_reclassification:.2f} seconds")
-
-         # Compute essential metrics
-         start_metrics = time.time()
-         semantic_similarity = compute_semantic_similarity(comment, paraphrased_comment)
-         empathy_score = compute_empathy_score(paraphrased_comment)
-         bleu_score = compute_bleu_score(comment, paraphrased_comment)
-         rouge_scores = compute_rouge_score(comment, paraphrased_comment)
-         print(f"Metrics computation took {time.time() - start_metrics:.2f} seconds")

-     print(f"Total processing time: {time.time() - start_total:.2f} seconds")

-     return (
-         f"Prediction: {label}", confidence, label_color, toxicity_score, bias_score,
-         paraphrased_comment, f"Prediction: {paraphrased_label}" if paraphrased_comment else None,
-         paraphrased_confidence, paraphrased_color, paraphrased_toxicity_score, paraphrased_bias_score,
-         semantic_similarity, empathy_score, bleu_score, rouge_scores
-     )
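
For context, a minimal sketch of how the 15-value tuple returned by the removed classify_toxic_comment could be unpacked against the pre-commit classifier.py; the input string and the local variable names are illustrative only, and the app wiring that actually consumed this tuple is not shown in this diff.

# Hypothetical caller, not part of the commit
from classifier import classify_toxic_comment

(prediction, confidence, color, toxicity, bias,
 paraphrase, paraphrase_prediction, paraphrase_confidence, paraphrase_color,
 paraphrase_toxicity, paraphrase_bias,
 similarity, empathy, bleu, rouge) = classify_toxic_comment("Example comment to classify")

print(prediction, confidence, toxicity)
if paraphrase:  # only populated when the original comment was classified as Toxic
    print(paraphrase, similarity, empathy, bleu, rouge)
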
+ # metrics.py
+ import nltk
+ from nltk.translate.bleu_score import sentence_bleu
+ from rouge_score import rouge_scorer
+ from model_loader import metrics_models

+ # Download required NLTK data
+ nltk.download('punkt')
+
+ def compute_semantic_similarity(original, paraphrased):
      """
+     Compute semantic similarity between the original and paraphrased comment using Sentence-BERT.
+     Returns a similarity score between 0 and 1.
      """
+     try:
+         sentence_bert = metrics_models.load_sentence_bert()
+         embeddings = sentence_bert.encode([original, paraphrased])
+         similarity = float(embeddings[0] @ embeddings[1].T)
+         return round(similarity, 2)
+     except Exception as e:
+         print(f"Error computing semantic similarity: {str(e)}")
+         return None

+ def compute_empathy_score(paraphrased):
+     """
+     Compute an empathy score for the paraphrased comment (placeholder).
+     Returns a score between 0 and 1.
+     """
+     try:
+         # Placeholder: Compute empathy based on word presence (e.g., "sorry", "understand")
+         empathy_words = ["sorry", "understand", "care", "help", "support"]
+         words = paraphrased.lower().split()
+         empathy_count = sum(1 for word in words if word in empathy_words)
+         score = empathy_count / len(words) if words else 0
+         return round(score, 2)
+     except Exception as e:
+         print(f"Error computing empathy score: {str(e)}")
+         return None

+ def compute_bleu_score(original, paraphrased):
+     """
+     Compute the BLEU score between the original and paraphrased comment.
+     Returns a score between 0 and 1.
+     """
+     try:
+         reference = [nltk.word_tokenize(original.lower())]
+         candidate = nltk.word_tokenize(paraphrased.lower())
+         score = sentence_bleu(reference, candidate, weights=(0.25, 0.25, 0.25, 0.25))
+         return round(score, 2)
+     except Exception as e:
+         print(f"Error computing BLEU score: {str(e)}")
+         return None

+ def compute_rouge_score(original, paraphrased):
+     """
+     Compute ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L) between the original and paraphrased comment.
+     Returns a dictionary with ROUGE scores.
+     """
+     try:
+         scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+         scores = scorer.score(original, paraphrased)
+         return {
+             'rouge1': round(scores['rouge1'].fmeasure, 2),
+             'rouge2': round(scores['rouge2'].fmeasure, 2),
+             'rougeL': round(scores['rougeL'].fmeasure, 2)
+         }
+     except Exception as e:
+         print(f"Error computing ROUGE scores: {str(e)}")
+         return None
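
A minimal usage sketch for the metric helpers added in this commit, assuming they are importable as a `metrics` module (the new file body is headed `# metrics.py` even though the commit targets classifier.py), that `model_loader.metrics_models.load_sentence_bert()` returns a loaded Sentence-BERT model, and that the `nltk` and `rouge_score` packages are installed; the example strings are illustrative only.

# Hypothetical driver, not part of the commit
from metrics import (
    compute_semantic_similarity,
    compute_empathy_score,
    compute_bleu_score,
    compute_rouge_score,
)

original = "Your idea is useless."                                   # illustrative toxic comment
paraphrased = "I understand your point, but I see it differently."   # illustrative paraphrase

print("Semantic similarity:", compute_semantic_similarity(original, paraphrased))
print("Empathy score:", compute_empathy_score(paraphrased))
print("BLEU:", compute_bleu_score(original, paraphrased))
print("ROUGE:", compute_rouge_score(original, paraphrased))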