import gradio as gr
from transformers import pipeline
from transformers import Tool


class SentimentAnalysisTool(Tool):
    """Tool that scores the sentiment of a text with a Hugging Face pipeline.

    One ``text-classification`` pipeline is created per model ID, on first use,
    and cached on the instance so later predictions with that model reuse it.
    """

    name = "sentiment_analysis"
    description = "This tool analyses the sentiment of a given text."
    inputs = {"text": {"type": "text", "description": "The text to analyze"}}
    outputs = {"json": {"type": "json", "description": "Sentiment analysis results"}}

    # Available sentiment analysis models: short key -> Hugging Face model ID.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert",
    }

    def __init__(self, default_model="distilbert"):
        """Initialize the tool and eagerly load the default model.

        Args:
            default_model: Key into ``models`` used whenever ``predict`` is
                called without a recognised ``model_key``.

        Raises:
            KeyError: If ``default_model`` is not a key of ``models``.
        """
        # Let the Tool base class initialise its own per-instance state.
        super().__init__()
        self.default_model = default_model
        # Cache of model ID -> pipeline, filled lazily by get_classifier().
        self._classifiers = {}
        # Pre-load the default model to speed up first inference.
        self.get_classifier(self.models[default_model])

    def __call__(self, text: str, model_key=None):
        """Process ``text`` and return sentiment predictions.

        Args:
            text: The text to analyse.
            model_key: Optional key into ``models``; defaults to the model
                chosen at construction time.

        Returns:
            A list of ``(label, score)`` tuples, as returned by ``predict``.
        """
        return self.predict(text, model_key=model_key)

    def parse_output(self, output_json):
        """Convert raw pipeline output into a list of ``(label, score)`` tuples.

        Args:
            output_json: Pipeline output for one input; only its first element
                (a sequence of dicts with ``'label'`` and ``'score'`` keys) is read.

        Returns:
            A list of ``(label, score)`` tuples in the order the pipeline
            produced them.
        """
        return [(entry["label"], entry["score"]) for entry in output_json[0]]

    def get_classifier(self, model_id):
        """Return the cached pipeline for ``model_id``, creating it on first use.

        Args:
            model_id: Hugging Face model identifier, e.g. a value of ``models``.

        Returns:
            A ``text-classification`` pipeline configured to return every
            label's score.
        """
        if model_id not in self._classifiers:
            # NOTE(review): newer transformers releases deprecate
            # `return_all_scores` in favour of `top_k=None`. The replacement
            # sorts labels by score and may change the nesting that
            # parse_output() relies on, so verify before migrating.
            self._classifiers[model_id] = pipeline(
                "text-classification", model=model_id, return_all_scores=True
            )
        return self._classifiers[model_id]

    def predict(self, text, model_key=None):
        """Classify ``text`` with the requested model, or the default one.

        Args:
            text: The text to analyse.
            model_key: Optional key into ``models``. ``None`` or an
                unrecognised key silently falls back to ``default_model``.

        Returns:
            A list of ``(label, score)`` tuples from ``parse_output``.
        """
        key = model_key if model_key in self.models else self.default_model
        classifier = self.get_classifier(self.models[key])
        return self.parse_output(classifier(text))
# For standalone testing
if __name__ == "__main__":
    # Smoke test: build the tool with its default model and classify one
    # hand-written review, echoing both the input and the predictions.
    sample = "I really enjoyed this product. It exceeded my expectations!"
    tool = SentimentAnalysisTool()
    predictions = tool(sample)
    print(f"Input: {sample}")
    print(f"Result: {predictions}")