NLP_2k25_Project / emotion_detection.py
Harshb11's picture
Update emotion_detection.py
11fa374 verified
raw
history blame
3.15 kB
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer
import torch
class EmotionDetection:
    """
    Emotion Detection on text data.

    Loads a sequence-classification checkpoint from the Hugging Face hub and
    pairs every prediction with an emoji, a one-line explanation and an
    attribution visualization.

    Attributes:
        tokenizer: An instance of Hugging Face Tokenizer
        model: An instance of Hugging Face Model
        explainer: An instance of SequenceClassificationExplainer from Transformers interpret
        emoji_map (dict): Maps an emotion label to the emoji shown in the result
        explanation_map (dict): Maps an emotion label to a plain-language explanation
    """

    # Upper bound on the number of tokens `classify` sends to the model.
    # NOTE(review): 512 is the usual input limit of RoBERTa-base checkpoints;
    # confirm against the model config if the checkpoint is ever changed.
    MAX_INPUT_TOKENS = 512

    def __init__(self):
        hub_location = 'cardiffnlp/twitter-roberta-base-emotion'
        self.tokenizer = AutoTokenizer.from_pretrained(hub_location)
        self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
        self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)

        # Friendly emoji mapping
        self.emoji_map = {
            "joy": "😊",
            "anger": "😠",
            "optimism": "😎",
            "sadness": "😒",
        }

        # Friendly explanation mapping
        self.explanation_map = {
            "joy": "The person is happy or excited.",
            "anger": "The person is upset or angry.",
            "optimism": "The person is feeling hopeful or positive.",
            "sadness": "The person is feeling low or unhappy.",
        }

    def justify(self, text):
        """
        Get html annotation for displaying emotion justification over text.

        Parameters:
            text (str): The user input string for emotion justification

        Returns:
            html (str): html string for plotting emotion prediction justification
        """
        # Calling the explainer computes the word attributions that
        # `visualize` renders; the returned list itself is not needed here.
        self.explainer(text)

        # NOTE(review): upstream transformers-interpret documents
        # `visualize(html_filepath=None, true_class=None)`, without a
        # `return_html` keyword - confirm against the installed version.
        # The original call is tried first so a build that does accept the
        # keyword keeps its behavior; otherwise fall back to the plain call.
        try:
            rendered = self.explainer.visualize(return_html=True)
        except TypeError:
            rendered = self.explainer.visualize()

        # The plain call presumably returns a display object carrying the
        # markup in `.data`; unwrap it so callers always receive a string.
        return getattr(rendered, "data", rendered)

    def classify(self, text):
        """
        Recognize Emotion in text (simplified).

        Parameters:
            text (str): The user input string to perform emotion classification on

        Returns:
            result (str): User-friendly emotion label with emoji and explanation
        """
        # Truncate so that very long inputs cannot exceed the model's
        # maximum sequence length and crash inside the forward pass.
        encoded = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_INPUT_TOKENS,
        )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**encoded)

        probs = torch.nn.functional.softmax(outputs[0], dim=-1)
        # Collapse the leading (batch) axis into one distribution over labels.
        probs = probs.mean(dim=0)

        max_index = int(probs.argmax())
        confidence = float(probs[max_index])
        # Look the label up by class id instead of relying on the insertion
        # order of the id2label dictionary.
        emotion = self.model.config.id2label[max_index]

        emoji = self.emoji_map.get(emotion, "")
        explanation = self.explanation_map.get(emotion, "")
        result = f"{emoji} **{emotion.capitalize()}** ({confidence:.1%})\n{explanation}"
        return result

    def run(self, text):
        """
        Classify and Justify Emotion in text.

        Parameters:
            text (str): The user input string to perform emotion classification on

        Returns:
            result (str): Friendly emotion classification output
            html (str): HTML visualization string for model justification
        """
        result = self.classify(text)
        html = self.justify(text)
        return result, html