# Hugging Face Space status: Running
import gradio as gr | |
from transformers import pipeline | |
from smolagents import Tool | |
class SimpleSentimentTool(Tool):
    """smolagents tool that scores the sentiment of a text with a transformers pipeline.

    Pipelines are created lazily with ``transformers.pipeline("text-classification")``
    and successfully loaded ones are cached per model ID on the instance.
    """

    name = "sentiment_analysis"
    description = "This tool analyzes the sentiment of a given text."
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment",
        },
        "model_key": {
            "type": "string",
            # List the accepted values so the calling agent can choose one;
            # without this it has no way to know the short keys below.
            "description": (
                "The model to use for sentiment analysis: one of the keys "
                "'multilingual', 'deberta', 'distilbert', 'mobilebert', "
                "'reviews', 'sbc', 'german', or the matching Hugging Face model ID"
            ),
            # A full model ID (not a short key); _resolve_model_id accepts both.
            "default": "oliverguhr/german-sentiment-bert",
            "nullable": True,
        },
    }
    # Use a standard authorized type
    output_type = "string"

    # Short keys accepted for ``model_key`` -> Hugging Face model IDs.
    # NOTE(review): "deberta" points to an MNLI checkpoint and "mobilebert" to an
    # emotion checkpoint (judging by their names); their labels are presumably
    # not positive/negative sentiment classes -- confirm they belong here.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert",
    }

    def __init__(self, default_model="distilbert", preload=False):
        """Initialize with a default model.

        Args:
            default_model: Key of ``models`` (or one of its model IDs) used when
                ``model_key`` is empty or not recognized.
            preload: If True, load the default model's pipeline now instead of
                on first use. A load failure is printed as a warning, not raised.
        """
        super().__init__()
        self.default_model = default_model
        # model ID -> loaded transformers pipeline
        self._classifiers = {}
        if preload:
            try:
                self._get_classifier(self._resolve_model_id(default_model))
            except Exception as e:
                print(f"Warning: Failed to preload model: {str(e)}")

    def _lookup_model_id(self, key):
        """Map a short key or a known model ID to a model ID; return None if unknown."""
        if key in self.models:
            return self.models[key]
        if key in self.models.values():
            return key
        return None

    def _resolve_model_id(self, model_key):
        """Return the model ID for ``model_key``, falling back to ``default_model``.

        Raises:
            KeyError: if the fallback ``default_model`` is not a known key or ID.
        """
        model_id = self._lookup_model_id(model_key) if model_key else None
        if model_id is not None:
            return model_id
        if model_key:
            print(
                f"Unknown model_key {model_key!r}; "
                f"using default model {self.default_model!r}"
            )
        model_id = self._lookup_model_id(self.default_model)
        if model_id is None:
            raise KeyError(self.default_model)
        return model_id

    def _get_classifier(self, model_id):
        """Get or create a classifier for the given model ID.

        If loading fails for any model other than distilbert, the distilbert
        pipeline is returned instead. The failed model is not cached, so a later
        request for it tries to load it again.

        Raises:
            RuntimeError: if the distilbert fallback itself cannot be loaded.
        """
        if model_id in self._classifiers:
            return self._classifiers[model_id]
        try:
            print(f"Loading model: {model_id}")
            classifier = pipeline(
                "text-classification",
                model=model_id,
                top_k=None,  # Return all scores
            )
        except Exception as e:
            print(f"Error loading model {model_id}: {str(e)}")
            fallback_id = self.models["distilbert"]
            if model_id == fallback_id:
                # Last resort - if even distilbert fails
                print("Critical error: Could not load default model")
                raise RuntimeError(
                    f"Failed to load any sentiment model: {str(e)}"
                ) from e
            print("Falling back to distilbert model...")
            return self._get_classifier(fallback_id)
        self._classifiers[model_id] = classifier
        return classifier

    def forward(self, text: str, model_key="oliverguhr/german-sentiment-bert"):
        """Process input text and return sentiment predictions.

        Args:
            text: The text to classify.
            model_key: Short key of ``models`` or one of its model IDs. Empty,
                None, or unrecognized values fall back to ``self.default_model``.

        Returns:
            A JSON string (indent=2) mapping each label to its float score, or
            ``{"error": "<message>"}`` if any step fails.
        """
        # Imported before the try so the error branch below can always use it.
        import json

        try:
            model_id = self._resolve_model_id(model_key)
            classifier = self._get_classifier(model_id)
            prediction = classifier(text)
            # With top_k=None a single string comes back either nested as
            # [[{label, score}, ...]] or flat as [{label, score}, ...] depending
            # on the transformers version (TODO confirm); accept both shapes.
            if prediction and isinstance(prediction[0], list):
                items = prediction[0]
            else:
                items = prediction
            result = {item["label"]: float(item["score"]) for item in items}
            return json.dumps(result, indent=2)
        except Exception as e:
            print(f"Error in sentiment analysis: {str(e)}")
            return json.dumps({"error": str(e)}, indent=2)