# NLP_2k25_Project / emotion_detection.py
# Last updated by Harshb11 in commit 08d98c4 ("Update emotion_detection.py", verified); 3.13 kB.
import os
import tempfile

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer
class EmotionDetection:
    """
    Emotion detection on text data.

    Pairs a Hugging Face sequence-classification model (by default
    ``cardiffnlp/twitter-roberta-base-emotion``) with a transformers-interpret
    explainer, so callers can get both a predicted emotion and a
    word-attribution visualization for the same text.

    Attributes:
        tokenizer: Hugging Face tokenizer instance.
        model: Hugging Face sequence-classification model.
        explainer: SequenceClassificationExplainer instance for model
            interpretability.
        emoji_map: Label -> emoji shown next to the predicted emotion.
        explanation_map: Label -> short plain-English explanation.
    """

    def __init__(self, hub_location: str = "cardiffnlp/twitter-roberta-base-emotion"):
        """
        Load the tokenizer, model and explainer.

        Parameters:
            hub_location (str): Model id (or anything ``from_pretrained``
                accepts) to load. Defaults to the original hard-coded model,
                so ``EmotionDetection()`` behaves exactly as before.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(hub_location)
        self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)
        self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)
        # Emoji map for friendly display. classify() falls back to an empty
        # string for any label not listed here.
        self.emoji_map = {
            "joy": "😊",
            "anger": "😠",
            "optimism": "😎",
            # NOTE(review): 😒 is "unamused face"; 😢 may have been intended - confirm.
            "sadness": "😒",
        }
        # Simple explanation map; unlisted labels get an empty explanation.
        self.explanation_map = {
            "joy": "The person is happy or excited.",
            "anger": "The person is upset or angry.",
            "optimism": "The person is feeling hopeful or positive.",
            "sadness": "The person is feeling low or unhappy.",
        }

    def justify(self, text: str) -> str:
        """
        Generate an HTML visualization of word attributions for the emotion.

        The explainer only writes its visualization to a file, so the HTML is
        written into a private temporary directory, read back, and the
        directory is removed even if reading fails. Concurrent calls
        therefore never share or overwrite each other's output file.

        Parameters:
            text (str): Input text.

        Returns:
            html (str): HTML string with the justification visualization.
        """
        # Run the explainer on the text. visualize() below is called without
        # the text, so it presumably renders the attributions from this call.
        self.explainer(text)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Keep the ".html" suffix so the explainer uses the path as-is.
            html_path = os.path.join(tmp_dir, "justification_output.html")
            self.explainer.visualize(html_path)
            with open(html_path, "r", encoding="utf-8") as f:
                return f.read()

    def classify(self, text: str) -> str:
        """
        Classify the main emotion in the input text.

        Parameters:
            text (str): Input text. Text longer than the model's maximum input
                length is truncated instead of making the forward pass fail.

        Returns:
            result (str): Friendly output ``"<emoji> **<Label>** (<pct>)"``
                followed by a newline and the short explanation. The emoji and
                explanation are empty strings for labels missing from
                ``emoji_map`` / ``explanation_map``.
        """
        # truncation=True: over-long input would otherwise exceed the model's
        # maximum sequence length and can crash the forward pass.
        tokens = self.tokenizer(text, return_tensors="pt", truncation=True)
        # Inference only - no need to build an autograd graph.
        with torch.no_grad():
            outputs = self.model(**tokens)
        # outputs[0] holds the logits; row 0 is the single input text.
        probs = torch.nn.functional.softmax(outputs[0], dim=-1)[0]
        max_index = int(probs.argmax())
        # Look the label up by id instead of relying on id2label's dict order.
        emotion = self.model.config.id2label[max_index]
        confidence = float(probs[max_index])
        emoji = self.emoji_map.get(emotion, "")
        explanation = self.explanation_map.get(emotion, "")
        return f"{emoji} **{emotion.capitalize()}** ({confidence:.1%})\n{explanation}"

    def run(self, text: str):
        """
        Perform classification and justification on the same text.

        Parameters:
            text (str): Input text.

        Returns:
            tuple: ``(result, html)`` where ``result`` (str) is the output of
            :meth:`classify` and ``html`` (str) is the output of
            :meth:`justify`.
        """
        result = self.classify(text)
        html = self.justify(text)
        return result, html