sentiment-tool / sentiment_analysis.py
Chris4K's picture
Update sentiment_analysis.py
fd41f0a verified
raw
history blame
2.8 kB
import gradio as gr
from transformers import pipeline
from transformers import Tool
class SentimentAnalysisTool(Tool):
    """Agent tool that classifies the sentiment of a text with a transformers pipeline.

    Several checkpoints are registered in ``models`` under short keys. One key is
    chosen as the default at construction time; others can be selected per call
    via ``model_key``. Pipelines are created lazily and cached per model id, so
    each checkpoint is loaded at most once per tool instance.
    """

    name = "sentiment_analysis"
    description = "This tool analyses the sentiment of a given text."
    # NOTE(review): newer transformers agent Tool schemas use types such as
    # "string" and an ``output_type`` attribute; confirm "text"/"json" match the
    # transformers version this tool targets.
    inputs = {"text": {"type": "text", "description": "The text to analyze"}}
    outputs = {"json": {"type": "json", "description": "Sentiment analysis results"}}

    # Available sentiment analysis models: short key -> Hugging Face Hub model id.
    # NOTE(review): "deberta" points at an MNLI (entailment) checkpoint, whose
    # labels are not sentiment classes — confirm it is meant to be here.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert",
    }

    def __init__(self, default_model="distilbert"):
        """Initialize the tool and eagerly load the default model.

        Args:
            default_model: Key into ``models`` used whenever ``predict`` is not
                given a recognised ``model_key``.

        Raises:
            KeyError: If ``default_model`` is not a key of ``models``.
        """
        # Let the Tool base class initialise its own state before ours.
        super().__init__()
        if default_model not in self.models:
            raise KeyError(
                f"Unknown default_model {default_model!r}; "
                f"expected one of {sorted(self.models)}"
            )
        self.default_model = default_model
        # model id -> pipeline; filled lazily by get_classifier.
        self._classifiers = {}
        # Pre-load the default model to speed up first inference.
        self.get_classifier(self.models[default_model])

    def __call__(self, text: str, model_key=None):
        """Classify ``text`` and return a list of ``(label, score)`` tuples.

        ``model_key`` is optional and forwarded to ``predict``; omitting it keeps
        the original behaviour of using the default model.
        """
        return self.predict(text, model_key=model_key)

    def parse_output(self, output_json):
        """Convert raw pipeline output into a list of ``(label, score)`` tuples.

        Two shapes are accepted:
        - nested, as produced for one text with ``return_all_scores=True``:
          ``[[{"label": ..., "score": ...}, ...]]``; only the first inner list
          (the first text's scores) is used;
        - flat, a list of ``{"label": ..., "score": ...}`` dicts.

        Empty output yields an empty list.
        """
        if not output_json:
            return []
        first = output_json[0]
        # A dict at the top level means the output is already the flat score list.
        scores = output_json if isinstance(first, dict) else first
        return [(entry["label"], entry["score"]) for entry in scores]

    def get_classifier(self, model_id):
        """Return the cached text-classification pipeline for ``model_id``.

        The pipeline is created on first request and reused afterwards.
        """
        if model_id not in self._classifiers:
            # NOTE(review): ``return_all_scores`` is deprecated in recent
            # transformers in favour of ``top_k=None`` (a flat, score-sorted
            # list); parse_output already accepts both shapes.
            self._classifiers[model_id] = pipeline(
                "text-classification",
                model=model_id,
                return_all_scores=True,
            )
        return self._classifiers[model_id]

    def predict(self, text, model_key=None):
        """Classify ``text`` with the model named by ``model_key``.

        Args:
            text: Input text to classify.
            model_key: Key into ``models``. ``None`` or an unrecognised key falls
                back to the default model.

        Returns:
            A list of ``(label, score)`` tuples covering every label of the model.
        """
        key = model_key if model_key in self.models else self.default_model
        classifier = self.get_classifier(self.models[key])
        prediction = classifier(text)
        return self.parse_output(prediction)
# Standalone smoke test: running this file directly classifies one sample sentence
# with the default model and prints the input alongside the (label, score) list.
if __name__ == "__main__":
    sample = "I really enjoyed this product. It exceeded my expectations!"
    tool = SentimentAnalysisTool()
    scores = tool(sample)
    print(f"Input: {sample}")
    print(f"Result: {scores}")