Harshb11 commited on
Commit
2debc3e
Β·
verified Β·
1 Parent(s): 08d98c4

Update emotion_detection.py

Browse files
Files changed (1) hide show
  1. emotion_detection.py +24 -50
emotion_detection.py CHANGED
@@ -1,16 +1,16 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from transformers_interpret import SequenceClassificationExplainer
3
  import torch
4
- import os
5
 
6
 
7
  class EmotionDetection:
8
  """
9
  Emotion Detection on text data.
10
  Attributes:
11
- tokenizer: Hugging Face Tokenizer instance
12
- model: Hugging Face Sequence Classification model
13
- explainer: SequenceClassificationExplainer instance for model interpretability
14
  """
15
 
16
  def __init__(self):
@@ -19,75 +19,49 @@ class EmotionDetection:
19
  self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
20
  self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)
21
 
22
- # Emoji map for friendly display
23
- self.emoji_map = {
24
- "joy": "😊",
25
- "anger": "😠",
26
- "optimism": "😎",
27
- "sadness": "😒"
28
- }
29
-
30
- # Simple explanation map
31
- self.explanation_map = {
32
- "joy": "The person is happy or excited.",
33
- "anger": "The person is upset or angry.",
34
- "optimism": "The person is feeling hopeful or positive.",
35
- "sadness": "The person is feeling low or unhappy."
36
- }
37
-
38
  def justify(self, text):
39
  """
40
- Generate HTML visualization of word attributions for emotion.
41
  Parameters:
42
- text (str): Input text
43
  Returns:
44
- html (str): HTML string with justification visualization
45
  """
46
- word_attributions = self.explainer(text)
47
- html_path = "justification_output.html"
48
- self.explainer.visualize(html_path)
49
 
50
- # Read from file
51
- with open(html_path, "r", encoding="utf-8") as f:
52
- html = f.read()
53
 
54
- # Clean up file
55
- os.remove(html_path)
56
  return html
57
 
58
  def classify(self, text):
59
  """
60
- Classify the main emotion in the input text.
61
  Parameters:
62
- text (str): Input text
63
  Returns:
64
- result (str): Friendly output with emoji and short explanation
65
  """
66
- tokens = self.tokenizer.encode_plus(text, return_tensors='pt')
 
67
  outputs = self.model(**tokens)
68
  probs = torch.nn.functional.softmax(outputs[0], dim=-1)
69
  probs = probs.mean(dim=0).detach().numpy()
70
-
71
  labels = list(self.model.config.id2label.values())
72
- max_index = probs.argmax()
73
- emotion = labels[max_index]
74
- confidence = probs[max_index]
75
-
76
- emoji = self.emoji_map.get(emotion, "")
77
- explanation = self.explanation_map.get(emotion, "")
78
 
79
- result = f"{emoji} **{emotion.capitalize()}** ({confidence:.1%})\n{explanation}"
80
- return result
81
 
82
  def run(self, text):
83
  """
84
- Perform classification and justification.
85
  Parameters:
86
- text (str): Input text
87
  Returns:
88
- result (str): Emotion classification result
89
- html (str): Justification HTML
90
  """
91
- result = self.classify(text)
 
92
  html = self.justify(text)
93
- return result, html
 
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from transformers_interpret import SequenceClassificationExplainer
3
  import torch
4
+ import pandas as pd
5
 
6
 
7
  class EmotionDetection:
8
  """
9
  Emotion Detection on text data.
10
  Attributes:
11
+ tokenizer: An instance of Hugging Face Tokenizer
12
+ model: An instance of Hugging Face Model
13
+ explainer: An instance of SequenceClassificationExplainer from Transformers interpret
14
  """
15
 
16
  def __init__(self):
 
19
  self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
20
  self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def justify(self, text):
23
  """
24
+ Get html annotation for displaying emotion justification over text.
25
  Parameters:
26
+ text (str): The user input string to emotion justification
27
  Returns:
28
+ html (hmtl): html object for plotting emotion prediction justification
29
  """
 
 
 
30
 
31
+ word_attributions = self.explainer(text)
32
+ html = self.explainer.visualize("example.html")
 
33
 
 
 
34
  return html
35
 
36
  def classify(self, text):
37
  """
38
+ Recognize Emotion in text.
39
  Parameters:
40
+ text (str): The user input string to perform emotion classification on
41
  Returns:
42
+ predictions (str): The predicted probabilities for emotion classes
43
  """
44
+
45
+ tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
46
  outputs = self.model(**tokens)
47
  probs = torch.nn.functional.softmax(outputs[0], dim=-1)
48
  probs = probs.mean(dim=0).detach().numpy()
 
49
  labels = list(self.model.config.id2label.values())
50
+ preds = pd.Series(probs, index=labels, name='Predicted Probability')
 
 
 
 
 
51
 
52
+ return preds
 
53
 
54
  def run(self, text):
55
  """
56
+ Classify and Justify Emotion in text.
57
  Parameters:
58
+ text (str): The user input string to perform emotion classification on
59
  Returns:
60
+ predictions (str): The predicted probabilities for emotion classes
61
+ html (hmtl): html object for plotting emotion prediction justification
62
  """
63
+
64
+ preds = self.classify(text)
65
  html = self.justify(text)
66
+
67
+ return preds, html