JanviMl committed
Commit ca75f71 · verified · 1 Parent(s): 2550e3f

Update classifier.py

Files changed (1)
  1. classifier.py +88 -63
classifier.py CHANGED
@@ -1,69 +1,94 @@
-# metrics.py
-import nltk
-from nltk.translate.bleu_score import sentence_bleu
-from rouge_score import rouge_scorer
-from model_loader import metrics_models
-
-# Download required NLTK data
-nltk.download('punkt')
-
-def compute_semantic_similarity(original, paraphrased):
-    """
-    Compute semantic similarity between the original and paraphrased comment using Sentence-BERT.
-    Returns a similarity score between 0 and 1.
-    """
-    try:
-        sentence_bert = metrics_models.load_sentence_bert()
-        embeddings = sentence_bert.encode([original, paraphrased])
-        similarity = float(embeddings[0] @ embeddings[1].T)
-        return round(similarity, 2)
-    except Exception as e:
-        print(f"Error computing semantic similarity: {str(e)}")
-        return None
-
-def compute_empathy_score(paraphrased):
-    """
-    Compute an empathy score for the paraphrased comment (placeholder).
-    Returns a score between 0 and 1.
-    """
-    try:
-        # Placeholder: Compute empathy based on word presence (e.g., "sorry", "understand")
-        empathy_words = ["sorry", "understand", "care", "help", "support"]
-        words = paraphrased.lower().split()
-        empathy_count = sum(1 for word in words if word in empathy_words)
-        score = empathy_count / len(words) if words else 0
-        return round(score, 2)
-    except Exception as e:
-        print(f"Error computing empathy score: {str(e)}")
-        return None
-
-def compute_bleu_score(original, paraphrased):
-    """
-    Compute the BLEU score between the original and paraphrased comment.
-    Returns a score between 0 and 1.
-    """
-    try:
-        reference = [nltk.word_tokenize(original.lower())]
-        candidate = nltk.word_tokenize(paraphrased.lower())
-        score = sentence_bleu(reference, candidate, weights=(0.25, 0.25, 0.25, 0.25))
-        return round(score, 2)
-    except Exception as e:
-        print(f"Error computing BLEU score: {str(e)}")
-        return None
-
-def compute_rouge_score(original, paraphrased):
-    """
-    Compute ROUGE scores (ROUGE-1, ROUGE-2, ROUGE-L) between the original and paraphrased comment.
-    Returns a dictionary with ROUGE scores.
-    """
-    try:
-        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
-        scores = scorer.score(original, paraphrased)
-        return {
-            'rouge1': round(scores['rouge1'].fmeasure, 2),
-            'rouge2': round(scores['rouge2'].fmeasure, 2),
-            'rougeL': round(scores['rougeL'].fmeasure, 2)
-        }
-    except Exception as e:
-        print(f"Error computing ROUGE scores: {str(e)}")
-        return None
+# classifier.py
+import torch
+import time
+from model_loader import classifier_model
+from paraphraser import paraphrase_comment
+from metrics import compute_semantic_similarity, compute_empathy_score
+
+def classify_toxic_comment(comment):
+    """
+    Classify a comment as toxic or non-toxic using the fine-tuned XLM-RoBERTa model.
+    If toxic, paraphrase the comment, re-evaluate, and compute essential metrics.
+    Returns the prediction label, confidence, color, toxicity score, bias score, paraphrased comment (if applicable), and its metrics.
+    """
+    start_total = time.time()
+    print("Starting classification...")
+
+    if not comment.strip():
+        return "Error: Please enter a comment.", None, None, None, None, None, None, None, None, None, None, None
+
+    # Access the model and tokenizer
+    model = classifier_model.model
+    tokenizer = classifier_model.tokenizer
+
+    # Tokenize the input comment
+    start_classification = time.time()
+    inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
+
+    # Run inference
+    with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits
+
+    # Get the predicted class (0 = non-toxic, 1 = toxic)
+    predicted_class = torch.argmax(logits, dim=1).item()
+    label = "Toxic" if predicted_class == 1 else "Non-Toxic"
+    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
+    label_color = "red" if label == "Toxic" else "green"
+
+    # Compute Toxicity Score (approximated as the probability of the toxic class)
+    toxicity_score = torch.softmax(logits, dim=1)[0][1].item()
+    toxicity_score = round(toxicity_score, 2)
+
+    # Simulate Bias Score (placeholder)
+    bias_score = 0.01 if label == "Non-Toxic" else 0.15
+    bias_score = round(bias_score, 2)
+    print(f"Classification took {time.time() - start_classification:.2f} seconds")
+
+    # If the comment is toxic, paraphrase it and compute essential metrics
+    paraphrased_comment = None
+    paraphrased_prediction = None
+    paraphrased_confidence = None
+    paraphrased_color = None
+    paraphrased_toxicity_score = None
+    paraphrased_bias_score = None
+    semantic_similarity = None
+    empathy_score = None
+
+    if label == "Toxic":
+        # Paraphrase the comment
+        start_paraphrase = time.time()
+        paraphrased_comment = paraphrase_comment(comment)
+        print(f"Paraphrasing took {time.time() - start_paraphrase:.2f} seconds")
+
+        # Re-evaluate the paraphrased comment
+        start_reclassification = time.time()
+        paraphrased_inputs = tokenizer(paraphrased_comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
+        with torch.no_grad():
+            paraphrased_outputs = model(**paraphrased_inputs)
+            paraphrased_logits = paraphrased_outputs.logits
+
+        paraphrased_predicted_class = torch.argmax(paraphrased_logits, dim=1).item()
+        paraphrased_label = "Toxic" if paraphrased_predicted_class == 1 else "Non-Toxic"
+        paraphrased_confidence = torch.softmax(paraphrased_logits, dim=1)[0][paraphrased_predicted_class].item()
+        paraphrased_color = "red" if paraphrased_label == "Toxic" else "green"
+        paraphrased_toxicity_score = torch.softmax(paraphrased_logits, dim=1)[0][1].item()
+        paraphrased_toxicity_score = round(paraphrased_toxicity_score, 2)
+        paraphrased_bias_score = 0.01 if paraphrased_label == "Non-Toxic" else 0.15  # Placeholder
+        paraphrased_bias_score = round(paraphrased_bias_score, 2)
+        print(f"Reclassification of paraphrased comment took {time.time() - start_reclassification:.2f} seconds")
+
+        # Compute essential metrics
+        start_metrics = time.time()
+        semantic_similarity = compute_semantic_similarity(comment, paraphrased_comment)
+        empathy_score = compute_empathy_score(paraphrased_comment)
+        print(f"Metrics computation took {time.time() - start_metrics:.2f} seconds")
+
+    print(f"Total processing time: {time.time() - start_total:.2f} seconds")
+
+    return (
+        f"Prediction: {label}", confidence, label_color, toxicity_score, bias_score,
+        paraphrased_comment, f"Prediction: {paraphrased_label}" if paraphrased_comment else None,
+        paraphrased_confidence, paraphrased_color, paraphrased_toxicity_score, paraphrased_bias_score,
+        semantic_similarity, empathy_score
+    )
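
For reference, a minimal driver sketch showing how the 13-element tuple returned by classify_toxic_comment could be unpacked in a quick local test. The file name, the __main__ guard, and the sample comment are assumptions for illustration and are not part of this commit; running it still requires the repo's model_loader, paraphraser, and metrics modules plus the underlying model weights.

# sanity_check.py (hypothetical): unpack the result tuple from classify_toxic_comment
from classifier import classify_toxic_comment

if __name__ == "__main__":
    # Hypothetical input comment, chosen only to trigger the toxic branch
    (prediction, confidence, color, toxicity, bias,
     paraphrase, paraphrase_prediction, paraphrase_confidence, paraphrase_color,
     paraphrase_toxicity, paraphrase_bias, similarity, empathy) = classify_toxic_comment(
        "Nobody cares about your useless opinion."
    )

    print(prediction, confidence, toxicity, bias)
    # Paraphrase-related fields are only populated when the original comment is classified as Toxic
    if paraphrase:
        print("Paraphrase:", paraphrase)
        print(paraphrase_prediction, paraphrase_confidence, paraphrase_toxicity, paraphrase_bias)
        print("Semantic similarity:", similarity, "| Empathy score:", empathy)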