# sentiment-tool / simple_sentiment.py
# Chris4K - "Update simple_sentiment.py" (commit 1dd56ad, verified)
import json

import gradio as gr
from smolagents import Tool
from transformers import pipeline
class SimpleSentimentTool(Tool):
    """smolagents tool that scores the sentiment of a text.

    ``model_key`` is resolved against the ``models`` registry. A Hugging Face
    ``text-classification`` pipeline for the resolved model is created lazily
    and cached per model ID. Per-label scores are returned as a JSON string.
    """

    name = "sentiment_analysis"
    description = "This tool analyzes the sentiment of a given text."
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment"
        },
        "model_key": {
            "type": "string",
            "description": "The model to use for sentiment analysis",
            "default": "oliverguhr/german-sentiment-bert",
            "nullable": True
        }
    }
    # Use a standard authorized type
    output_type = "string"

    # Available sentiment analysis models: short key -> Hugging Face model ID.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert"
    }

    def __init__(self, default_model="distilbert", preload=False):
        """Initialize with a default model.

        Args:
            default_model: Registry key (in ``models``) used when ``forward``
                receives an empty, ``None`` or unrecognized ``model_key``.
            preload: Whether to build the default model's pipeline now rather
                than on first use. A failed preload is reported and ignored.
        """
        super().__init__()
        self.default_model = default_model
        # Loaded pipelines, keyed by Hugging Face model ID.
        self._classifiers = {}
        # Optionally preload the default model
        if preload:
            try:
                self._get_classifier(self.models[default_model])
            except Exception as e:
                # Best-effort: loading is simply retried on first forward().
                print(f"Warning: Failed to preload model: {str(e)}")

    def _resolve_model_id(self, model_key):
        """Map ``model_key`` to a Hugging Face model ID.

        Accepts either a registry key (e.g. ``"german"``) or one of the model
        IDs listed in ``models`` (e.g. ``"oliverguhr/german-sentiment-bert"``,
        the declared default of ``forward``). Anything else, including ``None``
        and ``""``, resolves to the model registered under
        ``self.default_model``.

        Raises:
            KeyError: if the fallback ``self.default_model`` is not a key of
                ``models`` (only checked when the fallback is actually needed).
        """
        if model_key in self.models:
            return self.models[model_key]
        if model_key in self.models.values():
            return model_key
        return self.models[self.default_model]

    def _get_classifier(self, model_id):
        """Return the cached text-classification pipeline for ``model_id``.

        The pipeline is created on first request with ``top_k=None`` so that
        scores for all labels are returned. If loading fails, the distilbert
        pipeline is returned instead. That fallback is not cached under
        ``model_id``, so the failing model is tried again on the next call.

        Raises:
            RuntimeError: if the distilbert model itself cannot be loaded.
        """
        if model_id not in self._classifiers:
            try:
                print(f"Loading model: {model_id}")
                self._classifiers[model_id] = pipeline(
                    "text-classification",
                    model=model_id,
                    top_k=None  # Return all scores
                )
            except Exception as e:
                print(f"Error loading model {model_id}: {str(e)}")
                fallback_id = self.models["distilbert"]
                # Fall back to distilbert if available
                if model_id != fallback_id:
                    print("Falling back to distilbert model...")
                    return self._get_classifier(fallback_id)
                # Last resort - if even distilbert fails
                print("Critical error: Could not load default model")
                raise RuntimeError(f"Failed to load any sentiment model: {str(e)}") from e
        return self._classifiers[model_id]

    def forward(self, text: str, model_key="oliverguhr/german-sentiment-bert"):
        """Classify ``text`` and return per-label scores as a JSON string.

        Args:
            text: The text to analyze.
            model_key: A key of ``models`` or one of its model IDs. ``None``,
                empty or unrecognized values fall back to ``self.default_model``.

        Returns:
            An indented JSON object mapping each label reported by the
            classifier to its float score. On any failure the result is
            ``{"error": "<message>"}``; this method does not raise.
        """
        try:
            # Determine which model to use
            model_id = self._resolve_model_id(model_key)
            classifier = self._get_classifier(model_id)
            prediction = classifier(text)
            # A single string is expected to come back wrapped as
            # [[{"label": ..., "score": ...}, ...]]; also accept a flat list of
            # dicts in case the installed transformers version returns that.
            if prediction and isinstance(prediction[0], list):
                scores = prediction[0]
            else:
                scores = prediction
            result = {item["label"]: float(item["score"]) for item in scores}
            return json.dumps(result, indent=2)
        except Exception as e:
            # Report failures to the calling agent as data instead of raising.
            # `json` is a module-level import, so this path cannot hit an
            # unbound name even when the failure happened early.
            print(f"Error in sentiment analysis: {str(e)}")
            return json.dumps({"error": str(e)}, indent=2)