Harshb11 commited on
Commit
6cce2e6
Β·
verified Β·
1 Parent(s): 7a7ff91

Update emotion_detection.py

Browse files
Files changed (1) hide show
  1. emotion_detection.py +10 -24
emotion_detection.py CHANGED
@@ -3,6 +3,7 @@ from transformers_interpret import SequenceClassificationExplainer
3
  import torch
4
  import pandas as pd
5
 
 
6
  class EmotionDetection:
7
  """
8
  Emotion Detection on text data.
@@ -24,10 +25,12 @@ class EmotionDetection:
24
  Parameters:
25
  text (str): The user input string to emotion justification
26
  Returns:
27
- html (html): html object for plotting emotion prediction justification
28
  """
 
29
  word_attributions = self.explainer(text)
30
  html = self.explainer.visualize("example.html")
 
31
  return html
32
 
33
  def classify(self, text):
@@ -43,28 +46,9 @@ class EmotionDetection:
43
  outputs = self.model(**tokens)
44
  probs = torch.nn.functional.softmax(outputs[0], dim=-1)
45
  probs = probs.mean(dim=0).detach().numpy()
 
 
46
 
47
- # Original labels from model
48
- original_labels = list(self.model.config.id2label.values())
49
-
50
- # Only keep the 4 specific emotions and map custom names
51
- desired_labels = ['joy', 'anger', 'sadness', 'optimism']
52
- custom_labels = {
53
- 'joy': 'Happiness 😊',
54
- 'anger': 'Anger 😑',
55
- 'sadness': 'Sadness 😒',
56
- 'optimism': 'Hopeful ✨'
57
- }
58
-
59
- filtered_probs = []
60
- filtered_labels = []
61
-
62
- for label, prob in zip(original_labels, probs):
63
- if label in desired_labels:
64
- filtered_probs.append(prob)
65
- filtered_labels.append(custom_labels[label])
66
-
67
- preds = pd.Series(filtered_probs, index=filtered_labels, name='Predicted Probability')
68
  return preds
69
 
70
  def run(self, text):
@@ -74,8 +58,10 @@ class EmotionDetection:
74
  text (str): The user input string to perform emotion classification on
75
  Returns:
76
  predictions (str): The predicted probabilities for emotion classes
77
- html (html): html object for plotting emotion prediction justification
78
  """
 
79
  preds = self.classify(text)
80
  html = self.justify(text)
81
- return preds, html
 
 
3
  import torch
4
  import pandas as pd
5
 
6
+
7
  class EmotionDetection:
8
  """
9
  Emotion Detection on text data.
 
25
  Parameters:
26
  text (str): The user input string to emotion justification
27
  Returns:
28
+ html (hmtl): html object for plotting emotion prediction justification
29
  """
30
+
31
  word_attributions = self.explainer(text)
32
  html = self.explainer.visualize("example.html")
33
+
34
  return html
35
 
36
  def classify(self, text):
 
46
  outputs = self.model(**tokens)
47
  probs = torch.nn.functional.softmax(outputs[0], dim=-1)
48
  probs = probs.mean(dim=0).detach().numpy()
49
+ labels = list(self.model.config.id2label.values())
50
+ preds = pd.Series(probs, index=labels, name='Predicted Probability')
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return preds
53
 
54
  def run(self, text):
 
58
  text (str): The user input string to perform emotion classification on
59
  Returns:
60
  predictions (str): The predicted probabilities for emotion classes
61
+ html (hmtl): html object for plotting emotion prediction justification
62
  """
63
+
64
  preds = self.classify(text)
65
  html = self.justify(text)
66
+
67
+ return preds, html