import gradio as gr
from transformers import pipeline
from smolagents import Tool


class SimpleSentimentTool(Tool):
    """smolagents tool that scores the sentiment of a text with a transformers pipeline.

    One ``text-classification`` pipeline (all label scores, ``top_k=None``) is
    created lazily per model ID and cached on the instance. ``forward`` returns
    a JSON string mapping each label to its score, or ``{"error": ...}`` when
    anything fails.
    """

    name = "sentiment_analysis"
    description = "This tool analyzes the sentiment of a given text."
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment"
        },
        "model_key": {
            "type": "string",
            "description": "The model to use for sentiment analysis",
            "default": "oliverguhr/german-sentiment-bert",
            "nullable": True
        }
    }
    # Use a standard authorized type
    output_type = "string"

    # Available sentiment analysis models: short key -> Hugging Face model ID.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert"
    }

    def __init__(self, default_model="distilbert", preload=False):
        """Initialize with a default model.

        Args:
            default_model: Key of ``models`` (or one of its model IDs) used when
                ``forward`` receives an empty or unrecognized ``model_key``.
            preload: Whether to load the default model's pipeline now instead
                of on first use. A failed preload is only printed as a warning.
        """
        super().__init__()
        self.default_model = default_model
        # Loaded pipelines, keyed by Hugging Face model ID.
        self._classifiers = {}

        # Optionally preload the default model
        if preload:
            try:
                self._get_classifier(self._resolve_model_id(default_model))
            except Exception as e:
                print(f"Warning: Failed to preload model: {str(e)}")

    def _resolve_model_id(self, model_key):
        """Map a model key or known model ID to a Hugging Face model ID.

        ``model_key`` is tried first, then ``self.default_model``. Each may be
        a key of ``models`` (e.g. ``"german"``) or one of its values (e.g.
        ``"oliverguhr/german-sentiment-bert"``). Empty values are skipped.

        Raises:
            KeyError: if neither candidate names a known model.
        """
        known_ids = set(self.models.values())
        for candidate in (model_key, self.default_model):
            if not candidate:
                continue
            if candidate in self.models:
                return self.models[candidate]
            if candidate in known_ids:
                return candidate
        raise KeyError(self.default_model)

    def _get_classifier(self, model_id):
        """Return the cached pipeline for ``model_id``, loading it on first use.

        If loading fails for any model other than distilbert, this falls back
        to the distilbert pipeline. The failed model is not cached, so it is
        retried on the next request for it.

        Raises:
            RuntimeError: if the distilbert fallback itself cannot be loaded.
        """
        if model_id in self._classifiers:
            return self._classifiers[model_id]

        try:
            print(f"Loading model: {model_id}")
            classifier = pipeline(
                "text-classification",
                model=model_id,
                top_k=None  # Return all scores
            )
        except Exception as e:
            print(f"Error loading model {model_id}: {str(e)}")
            fallback_id = self.models["distilbert"]
            if model_id != fallback_id:
                print("Falling back to distilbert model...")
                return self._get_classifier(fallback_id)
            # Last resort - if even distilbert fails
            print("Critical error: Could not load default model")
            raise RuntimeError(f"Failed to load any sentiment model: {str(e)}") from e

        self._classifiers[model_id] = classifier
        return classifier

    def forward(self, text: str, model_key="oliverguhr/german-sentiment-bert"):
        """Classify ``text`` and return its label scores as a JSON string.

        Args:
            text: The text to analyze.
            model_key: A key of ``models`` or one of its model IDs. Empty or
                unrecognized values fall back to ``self.default_model``.

        Returns:
            A JSON object mapping each label to its float score, or
            ``{"error": "<message>"}`` if resolving, loading or inference fails.
        """
        # Imported outside the try: the except branch needs ``json`` even when
        # the failure happens before anything else in the try has run.
        import json

        try:
            model_id = self._resolve_model_id(model_key)
            classifier = self._get_classifier(model_id)
            prediction = classifier(text)

            # For a single string with top_k=None the pipeline is expected to
            # return [[{label, score}, ...]]; also accept a flat list of dicts
            # in case the transformers version returns that shape instead.
            if prediction and isinstance(prediction[0], list):
                scores = prediction[0]
            else:
                scores = prediction

            result = {item['label']: float(item['score']) for item in scores}
            return json.dumps(result, indent=2)
        except Exception as e:
            print(f"Error in sentiment analysis: {str(e)}")
            return json.dumps({"error": str(e)}, indent=2)