Harshb11 commited on
Commit
11fa374
Β·
verified Β·
1 Parent(s): 116ee61

Update emotion_detection.py

Browse files
Files changed (1) hide show
  1. emotion_detection.py +35 -25
emotion_detection.py CHANGED
@@ -1,13 +1,11 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from transformers_interpret import SequenceClassificationExplainer
3
  import torch
4
- import pandas as pd
5
 
6
 
7
  class EmotionDetection:
8
  """
9
  Emotion Detection on text data.
10
-
11
  Attributes:
12
  tokenizer: An instance of Hugging Face Tokenizer
13
  model: An instance of Hugging Face Model
@@ -20,55 +18,67 @@ class EmotionDetection:
20
  self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
21
  self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def justify(self, text):
24
  """
25
  Get html annotation for displaying emotion justification over text.
26
-
27
  Parameters:
28
- text (str): The user input string to emotion justification
29
-
30
  Returns:
31
- html (hmtl): html object for plotting emotion prediction justification
32
  """
33
-
34
  word_attributions = self.explainer(text)
35
- html = self.explainer.visualize("example.html")
36
-
37
  return html
38
 
39
  def classify(self, text):
40
  """
41
- Recognize Emotion in text.
42
-
43
  Parameters:
44
  text (str): The user input string to perform emotion classification on
45
-
46
  Returns:
47
- predictions (str): The predicted probabilities for emotion classes
48
  """
49
-
50
- tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
51
  outputs = self.model(**tokens)
52
  probs = torch.nn.functional.softmax(outputs[0], dim=-1)
53
  probs = probs.mean(dim=0).detach().numpy()
 
54
  labels = list(self.model.config.id2label.values())
55
- preds = pd.Series(probs, index=labels, name='Predicted Probability')
 
 
56
 
57
- return preds
 
 
 
 
58
 
59
  def run(self, text):
60
  """
61
  Classify and Justify Emotion in text.
62
-
63
  Parameters:
64
  text (str): The user input string to perform emotion classification on
65
-
66
  Returns:
67
- predictions (str): The predicted probabilities for emotion classes
68
- html (hmtl): html object for plotting emotion prediction justification
69
  """
70
-
71
- preds = self.classify(text)
72
  html = self.justify(text)
73
-
74
- return preds, html
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from transformers_interpret import SequenceClassificationExplainer
3
  import torch
 
4
 
5
 
6
  class EmotionDetection:
7
  """
8
  Emotion Detection on text data.
 
9
  Attributes:
10
  tokenizer: An instance of Hugging Face Tokenizer
11
  model: An instance of Hugging Face Model
 
18
  self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
19
  self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)
20
 
21
+ # Friendly emoji mapping
22
+ self.emoji_map = {
23
+ "joy": "😊",
24
+ "anger": "😠",
25
+ "optimism": "😎",
26
+ "sadness": "😒"
27
+ }
28
+
29
+ # Friendly explanation mapping
30
+ self.explanation_map = {
31
+ "joy": "The person is happy or excited.",
32
+ "anger": "The person is upset or angry.",
33
+ "optimism": "The person is feeling hopeful or positive.",
34
+ "sadness": "The person is feeling low or unhappy."
35
+ }
36
+
37
  def justify(self, text):
38
  """
39
  Get html annotation for displaying emotion justification over text.
 
40
  Parameters:
41
+ text (str): The user input string for emotion justification
 
42
  Returns:
43
+ html (str): html string for plotting emotion prediction justification
44
  """
 
45
  word_attributions = self.explainer(text)
46
+ html = self.explainer.visualize(return_html=True) # Changed to return HTML string
 
47
  return html
48
 
49
  def classify(self, text):
50
  """
51
+ Recognize Emotion in text (simplified).
 
52
  Parameters:
53
  text (str): The user input string to perform emotion classification on
 
54
  Returns:
55
+ result (str): User-friendly emotion label with emoji and explanation
56
  """
57
+ tokens = self.tokenizer.encode_plus(text, return_tensors='pt')
 
58
  outputs = self.model(**tokens)
59
  probs = torch.nn.functional.softmax(outputs[0], dim=-1)
60
  probs = probs.mean(dim=0).detach().numpy()
61
+
62
  labels = list(self.model.config.id2label.values())
63
+ max_index = probs.argmax()
64
+ emotion = labels[max_index]
65
+ confidence = probs[max_index]
66
 
67
+ emoji = self.emoji_map.get(emotion, "")
68
+ explanation = self.explanation_map.get(emotion, "")
69
+
70
+ result = f"{emoji} **{emotion.capitalize()}** ({confidence:.1%})\n{explanation}"
71
+ return result
72
 
73
  def run(self, text):
74
  """
75
  Classify and Justify Emotion in text.
 
76
  Parameters:
77
  text (str): The user input string to perform emotion classification on
 
78
  Returns:
79
+ result (str): Friendly emotion classification output
80
+ html (str): HTML visualization string for model justification
81
  """
82
+ result = self.classify(text)
 
83
  html = self.justify(text)
84
+ return result, html