File size: 2,801 Bytes
1c02c6e
 
e7e4d1a
c2f577f
e7e4d1a
c2f577f
5cb8034
c2f577f
fd41f0a
 
c2f577f
7b89327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc1edab
5cb8034
7b89327
 
17b66cc
 
7b89327
17b66cc
cc1edab
 
 
 
 
 
7b89327
 
 
 
 
 
 
 
 
17b66cc
7b89327
 
 
 
 
 
17b66cc
c2f577f
7b89327
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import warnings

import gradio as gr
from transformers import pipeline
from transformers import Tool

class SentimentAnalysisTool(Tool):
    """Transformers ``Tool`` that scores the sentiment of a piece of text.

    Wraps Hugging Face ``text-classification`` pipelines. Several models are
    selectable by a short key (see ``models``). Pipelines are created lazily
    and cached per model id, so each model is loaded at most once per instance.
    """

    name = "sentiment_analysis"
    description = "This tool analyses the sentiment of a given text."

    inputs = {"text": {"type": "text", "description": "The text to analyze"}}
    outputs = {"json": {"type": "json", "description": "Sentiment analysis results"}}

    # Available sentiment analysis models, keyed by the short name accepted by
    # ``__init__(default_model=...)`` and ``predict(model_key=...)``.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert",
    }

    def __init__(self, default_model="distilbert"):
        """Initialize the tool and eagerly load the default model.

        Args:
            default_model: Key into ``models``. It is used whenever ``predict``
                is called without a known ``model_key``.

        Raises:
            KeyError: If ``default_model`` is not a key of ``models``.
        """
        # Fail fast with a helpful message instead of a bare KeyError('x').
        if default_model not in self.models:
            raise KeyError(
                f"Unknown default_model {default_model!r}; "
                f"expected one of {sorted(self.models)}"
            )
        self.default_model = default_model
        # Per-instance cache: model id -> pipeline, filled by get_classifier.
        self._classifiers = {}
        # Pre-load the default model to speed up the first inference.
        self.get_classifier(self.models[default_model])

    def __call__(self, text: str):
        """Score ``text`` with the default model.

        Returns:
            A list of ``(label, score)`` tuples (see ``parse_output``).
        """
        return self.predict(text)

    def parse_output(self, output_json):
        """Flatten pipeline output into a list of ``(label, score)`` tuples.

        Two shapes are accepted:

        * nested, ``[[{"label": ..., "score": ...}, ...]]``, which the
          ``return_all_scores=True`` pipeline yields for a single string.
          Only the first inner list is used.
        * flat, ``[{"label": ..., "score": ...}, ...]``.

        Empty output yields ``[]``.
        """
        if not output_json:
            return []
        scores = output_json[0]
        # Flat shape: the first element is already a score dict, not a list.
        if isinstance(scores, dict):
            scores = output_json
        return [(item["label"], item["score"]) for item in scores]

    def get_classifier(self, model_id):
        """Return the cached pipeline for ``model_id``, creating it on first use.

        Args:
            model_id: Hugging Face model id (one of the values of ``models``).
        """
        if model_id not in self._classifiers:
            self._classifiers[model_id] = pipeline(
                "text-classification",
                model=model_id,
                # Score every label, not just the top one; parse_output relies
                # on the nested per-label list this produces.
                return_all_scores=True,
            )
        return self._classifiers[model_id]

    def predict(self, text, model_key=None):
        """Classify ``text`` with the model selected by ``model_key``.

        Args:
            text: Input text passed to the classification pipeline.
            model_key: Key into ``models``. ``None`` selects the default model.
                An unknown key also falls back to the default model, but a
                ``UserWarning`` is emitted so typos are not silently ignored.

        Returns:
            A list of ``(label, score)`` tuples.
        """
        model_id = self.models.get(model_key)
        if model_id is None:
            if model_key is not None:
                warnings.warn(
                    f"Unknown model_key {model_key!r}; falling back to "
                    f"default model {self.default_model!r}",
                    stacklevel=2,
                )
            model_id = self.models[self.default_model]
        classifier = self.get_classifier(model_id)

        prediction = classifier(text)
        return self.parse_output(prediction)

# Manual smoke test: run this file directly to score one sample sentence.
if __name__ == "__main__":
    sample = "I really enjoyed this product. It exceeded my expectations!"
    # Constructing the tool loads the default model up front.
    tool = SentimentAnalysisTool()
    scores = tool(sample)
    print(f"Input: {sample}")
    print(f"Result: {scores}")