File size: 3,132 Bytes
cd85011
 
 
08d98c4
cd85011
 
 
 
 
 
08d98c4
 
 
cd85011
 
 
 
 
 
 
 
08d98c4
11fa374
 
 
 
 
 
 
08d98c4
11fa374
 
 
 
 
 
 
cd85011
 
08d98c4
cd85011
08d98c4
cd85011
08d98c4
cd85011
 
08d98c4
 
 
 
 
 
 
 
 
cd85011
 
 
 
08d98c4
cd85011
08d98c4
cd85011
08d98c4
cd85011
11fa374
cd85011
 
 
11fa374
cd85011
11fa374
 
 
cd85011
11fa374
 
 
 
 
cd85011
 
 
08d98c4
cd85011
08d98c4
cd85011
08d98c4
 
cd85011
11fa374
cd85011
11fa374
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import tempfile

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer


class EmotionDetection:
    """
    Emotion Detection on text data.

    Combines a Hugging Face sequence-classification model with a
    transformers-interpret explainer. Callers get both a friendly predicted
    emotion and an HTML visualization of the word attributions behind it.

    Attributes:
        tokenizer: Hugging Face Tokenizer instance
        model: Hugging Face Sequence Classification model
        explainer: SequenceClassificationExplainer instance for model interpretability
        emoji_map: Emotion label -> emoji prepended to `classify` output
        explanation_map: Emotion label -> one-sentence description appended
            to `classify` output
    """

    def __init__(self, hub_location='cardiffnlp/twitter-roberta-base-emotion'):
        """
        Load the tokenizer, model and explainer.

        Parameters:
            hub_location (str): Hugging Face Hub id (or local path) of the
                sequence-classification checkpoint to load. The emoji and
                explanation maps are keyed by the labels "joy", "anger",
                "optimism" and "sadness". Any other label predicted by a
                different checkpoint gets empty emoji/explanation text in
                `classify`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(hub_location)
        self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
        self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)

        # Emoji map for friendly display
        self.emoji_map = {
            "joy": "😊",
            "anger": "😠",
            "optimism": "😎",
            "sadness": "😒"
        }

        # Simple explanation map
        self.explanation_map = {
            "joy": "The person is happy or excited.",
            "anger": "The person is upset or angry.",
            "optimism": "The person is feeling hopeful or positive.",
            "sadness": "The person is feeling low or unhappy."
        }

    def justify(self, text):
        """
        Generate HTML visualization of word attributions for emotion.

        The visualization is rendered into a unique temporary file, read
        back, and the file is always removed afterwards.

        Parameters:
            text (str): Input text
        Returns:
            html (str): HTML string with justification visualization
        """
        # Run the explainer on `text`. Its return value is not needed here;
        # presumably the call records the attributions that `visualize`
        # renders below.
        self.explainer(text)

        # A unique temp file (instead of a fixed name in the CWD) keeps
        # concurrent calls from overwriting each other's output. The ".html"
        # suffix is kept in case the explainer adjusts paths lacking it
        # (NOTE(review): confirm against the installed transformers-interpret).
        fd, html_path = tempfile.mkstemp(suffix=".html")
        os.close(fd)  # only the path is needed; the explainer opens it itself
        try:
            self.explainer.visualize(html_path)
            with open(html_path, "r", encoding="utf-8") as f:
                return f.read()
        finally:
            # Clean up even if visualization or reading failed.
            if os.path.exists(html_path):
                os.remove(html_path)

    def classify(self, text):
        """
        Classify the main emotion in the input text.

        Parameters:
            text (str): Input text
        Returns:
            result (str): Friendly output with emoji and short explanation,
                formatted as "<emoji> **<Emotion>** (<confidence>%)\\n<explanation>"
        """
        # Truncate so inputs longer than the tokenizer's maximum length are
        # cut rather than passed to the model at an unsupported length.
        tokens = self.tokenizer.encode_plus(text, return_tensors='pt', truncation=True)

        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            outputs = self.model(**tokens)

        # A single text is encoded, so the logits have one row; take it.
        probs = torch.nn.functional.softmax(outputs[0], dim=-1)[0]

        max_index = int(probs.argmax())
        # Look the label up by class id rather than relying on the
        # insertion order of id2label's values.
        emotion = self.model.config.id2label[max_index]
        confidence = float(probs[max_index])

        emoji = self.emoji_map.get(emotion, "")
        explanation = self.explanation_map.get(emotion, "")

        result = f"{emoji} **{emotion.capitalize()}** ({confidence:.1%})\n{explanation}"
        return result

    def run(self, text):
        """
        Perform classification and justification.

        Parameters:
            text (str): Input text
        Returns:
            result (str): Emotion classification result (see `classify`)
            html (str): Justification HTML (see `justify`)
        """
        result = self.classify(text)
        html = self.justify(text)
        return result, html