Harshb11 commited on
Commit
7a7ff91
Β·
verified Β·
1 Parent(s): 67c9522

Update emotion_detection.py

Browse files
Files changed (1) hide show
  1. emotion_detection.py +32 -38
emotion_detection.py CHANGED
@@ -3,7 +3,6 @@ from transformers_interpret import SequenceClassificationExplainer
3
  import torch
4
  import pandas as pd
5
 
6
-
7
  class EmotionDetection:
8
  """
9
  Emotion Detection on text data.
@@ -25,51 +24,48 @@ class EmotionDetection:
25
  Parameters:
26
  text (str): The user input string to emotion justification
27
  Returns:
28
- html (hmtl): html object for plotting emotion prediction justification
29
  """
30
-
31
  word_attributions = self.explainer(text)
32
  html = self.explainer.visualize("example.html")
33
-
34
  return html
35
 
36
  def classify(self, text):
37
- """
38
- Recognize Emotion in text.
39
- Parameters:
40
- text (str): The user input string to perform emotion classification on
41
- Returns:
42
- predictions (str): The predicted probabilities for emotion classes
43
- """
44
-
45
- tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
46
- outputs = self.model(**tokens)
47
- probs = torch.nn.functional.softmax(outputs[0], dim=-1)
48
- probs = probs.mean(dim=0).detach().numpy()
49
 
50
- # Original labels from model
51
- original_labels = list(self.model.config.id2label.values())
 
 
52
 
53
- # Only keep the 4 specific emotions and map custom names
54
- desired_labels = ['joy', 'anger', 'sadness', 'optimism']
55
- custom_labels = {
56
- 'joy': 'Happiness 😊',
57
- 'anger': 'Anger 😑',
58
- 'sadness': 'Sadness 😒',
59
- 'optimism': 'Hopeful ✨'
60
- }
61
 
62
- filtered_probs = []
63
- filtered_labels = []
 
 
 
 
 
 
64
 
65
- for label, prob in zip(original_labels, probs):
66
- if label in desired_labels:
67
- filtered_probs.append(prob)
68
- filtered_labels.append(custom_labels[label])
69
 
70
- preds = pd.Series(filtered_probs, index=filtered_labels, name='Predicted Probability')
 
 
 
71
 
72
- return preds
 
73
 
74
  def run(self, text):
75
  """
@@ -78,10 +74,8 @@ class EmotionDetection:
78
  text (str): The user input string to perform emotion classification on
79
  Returns:
80
  predictions (str): The predicted probabilities for emotion classes
81
- html (hmtl): html object for plotting emotion prediction justification
82
  """
83
-
84
  preds = self.classify(text)
85
  html = self.justify(text)
86
-
87
- return preds, html
 
3
  import torch
4
  import pandas as pd
5
 
 
6
  class EmotionDetection:
7
  """
8
  Emotion Detection on text data.
 
24
  Parameters:
25
  text (str): The user input string to emotion justification
26
  Returns:
27
+ html (html): html object for plotting emotion prediction justification
28
  """
 
29
  word_attributions = self.explainer(text)
30
  html = self.explainer.visualize("example.html")
 
31
  return html
32
 
33
  def classify(self, text):
34
+ """
35
+ Recognize Emotion in text.
36
+ Parameters:
37
+ text (str): The user input string to perform emotion classification on
38
+ Returns:
39
+ predictions (str): The predicted probabilities for emotion classes
40
+ """
 
 
 
 
 
41
 
42
+ tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
43
+ outputs = self.model(**tokens)
44
+ probs = torch.nn.functional.softmax(outputs[0], dim=-1)
45
+ probs = probs.mean(dim=0).detach().numpy()
46
 
47
+ # Original labels from model
48
+ original_labels = list(self.model.config.id2label.values())
 
 
 
 
 
 
49
 
50
+ # Only keep the 4 specific emotions and map custom names
51
+ desired_labels = ['joy', 'anger', 'sadness', 'optimism']
52
+ custom_labels = {
53
+ 'joy': 'Happiness 😊',
54
+ 'anger': 'Anger 😑',
55
+ 'sadness': 'Sadness 😒',
56
+ 'optimism': 'Hopeful ✨'
57
+ }
58
 
59
+ filtered_probs = []
60
+ filtered_labels = []
 
 
61
 
62
+ for label, prob in zip(original_labels, probs):
63
+ if label in desired_labels:
64
+ filtered_probs.append(prob)
65
+ filtered_labels.append(custom_labels[label])
66
 
67
+ preds = pd.Series(filtered_probs, index=filtered_labels, name='Predicted Probability')
68
+ return preds
69
 
70
  def run(self, text):
71
  """
 
74
  text (str): The user input string to perform emotion classification on
75
  Returns:
76
  predictions (str): The predicted probabilities for emotion classes
77
+ html (html): html object for plotting emotion prediction justification
78
  """
 
79
  preds = self.classify(text)
80
  html = self.justify(text)
81
+ return preds, html