Harshb11 commited on
Commit
67c9522
Β·
verified Β·
1 Parent(s): 264f686

Update emotion_detection.py

Browse files
Files changed (1) hide show
  1. emotion_detection.py +34 -14
emotion_detection.py CHANGED
@@ -34,22 +34,42 @@ class EmotionDetection:
34
  return html
35
 
36
  def classify(self, text):
37
- """
38
- Recognize Emotion in text.
39
- Parameters:
40
- text (str): The user input string to perform emotion classification on
41
- Returns:
42
- predictions (str): The predicted probabilities for emotion classes
43
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
46
- outputs = self.model(**tokens)
47
- probs = torch.nn.functional.softmax(outputs[0], dim=-1)
48
- probs = probs.mean(dim=0).detach().numpy()
49
- labels = list(self.model.config.id2label.values())
50
- preds = pd.Series(probs, index=labels, name='Predicted Probability')
51
 
52
- return preds
53
 
54
  def run(self, text):
55
  """
 
34
  return html
35
 
36
  def classify(self, text):
37
+ """
38
+ Recognize Emotion in text.
39
+ Parameters:
40
+ text (str): The user input string to perform emotion classification on
41
+ Returns:
42
+ predictions (str): The predicted probabilities for emotion classes
43
+ """
44
+
45
+ tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
46
+ outputs = self.model(**tokens)
47
+ probs = torch.nn.functional.softmax(outputs[0], dim=-1)
48
+ probs = probs.mean(dim=0).detach().numpy()
49
+
50
+ # Original labels from model
51
+ original_labels = list(self.model.config.id2label.values())
52
+
53
+ # Only keep the 4 specific emotions and map custom names
54
+ desired_labels = ['joy', 'anger', 'sadness', 'optimism']
55
+ custom_labels = {
56
+ 'joy': 'Happiness 😊',
57
+ 'anger': 'Anger 😑',
58
+ 'sadness': 'Sadness 😒',
59
+ 'optimism': 'Hopeful ✨'
60
+ }
61
+
62
+ filtered_probs = []
63
+ filtered_labels = []
64
+
65
+ for label, prob in zip(original_labels, probs):
66
+ if label in desired_labels:
67
+ filtered_probs.append(prob)
68
+ filtered_labels.append(custom_labels[label])
69
 
70
+ preds = pd.Series(filtered_probs, index=filtered_labels, name='Predicted Probability')
 
 
 
 
 
71
 
72
+ return preds
73
 
74
  def run(self, text):
75
  """