Chris4K commited on
Commit
7b89327
·
verified ·
1 Parent(s): 0065a7e

Update sentiment_analysis.py

Browse files
Files changed (1) hide show
  1. sentiment_analysis.py +46 -22
sentiment_analysis.py CHANGED
@@ -1,4 +1,3 @@
1
- import requests
2
  import gradio as gr
3
  from transformers import pipeline
4
  from transformers import Tool
@@ -7,21 +6,33 @@ class SentimentAnalysisTool(Tool):
7
  name = "sentiment_analysis"
8
  description = "This tool analyses the sentiment of a given text."
9
 
10
- inputs = ["text"] # Adding an empty list for inputs
11
  outputs = ["json"]
12
 
13
- model_id_1 = "nlptown/bert-base-multilingual-uncased-sentiment"
14
- model_id_2 = "microsoft/deberta-xlarge-mnli"
15
- model_id_3 = "distilbert-base-uncased-finetuned-sst-2-english"
16
- model_id_4 = "lordtt13/emo-mobilebert"
17
- model_id_5 = "juliensimon/reviews-sentiment-analysis"
18
- model_id_6 = "sbcBI/sentiment_analysis_model"
19
- model_id_7 = "models/oliverguhr/german-sentiment-bert"
 
 
 
 
 
 
 
 
 
 
20
 
21
  def __call__(self, text: str):
22
- return self.predicto(text)
 
23
 
24
  def parse_output(self, output_json):
 
25
  list_pred = []
26
  for i in range(len(output_json[0])):
27
  label = output_json[0][i]['label']
@@ -29,18 +40,31 @@ class SentimentAnalysisTool(Tool):
29
  list_pred.append((label, score))
30
  return list_pred
31
 
32
- def get_prediction(self, model_id):
33
- classifier = pipeline("text-classification", model=model_id, return_all_scores=True)
34
- return classifier
 
 
 
 
 
 
35
 
36
- def predicto(self, review):
37
- classifier = self.get_prediction(self.model_id_3)
38
- prediction = classifier(review)
39
- print(prediction)
 
 
40
  return self.parse_output(prediction)
41
 
42
- # Create an instance of the SentimentAnalysisTool class
43
- sentiment_analysis_tool = SentimentAnalysisTool()
44
-
45
- # Create the Gradio interface
46
- #gr.Interface(fn=sentiment_analysis_tool, inputs=sentiment_analysis_tool.inputs, outputs=sentiment_analysis_tool.outputs).launch(share=True)
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  from transformers import Tool
 
6
  name = "sentiment_analysis"
7
  description = "This tool analyses the sentiment of a given text."
8
 
9
+ inputs = ["text"]
10
  outputs = ["json"]
11
 
12
+ # Available sentiment analysis models
13
+ models = {
14
+ "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
15
+ "deberta": "microsoft/deberta-xlarge-mnli",
16
+ "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
17
+ "mobilebert": "lordtt13/emo-mobilebert",
18
+ "reviews": "juliensimon/reviews-sentiment-analysis",
19
+ "sbc": "sbcBI/sentiment_analysis_model",
20
+ "german": "oliverguhr/german-sentiment-bert"
21
+ }
22
+
23
+ def __init__(self, default_model="distilbert"):
24
+ """Initialize with a default model."""
25
+ self.default_model = default_model
26
+ # Pre-load the default model to speed up first inference
27
+ self._classifiers = {}
28
+ self.get_classifier(self.models[default_model])
29
 
30
  def __call__(self, text: str):
31
+ """Process input text and return sentiment predictions."""
32
+ return self.predict(text)
33
 
34
  def parse_output(self, output_json):
35
+ """Parse model output into a list of (label, score) tuples."""
36
  list_pred = []
37
  for i in range(len(output_json[0])):
38
  label = output_json[0][i]['label']
 
40
  list_pred.append((label, score))
41
  return list_pred
42
 
43
+ def get_classifier(self, model_id):
44
+ """Get or create a classifier for the given model ID."""
45
+ if model_id not in self._classifiers:
46
+ self._classifiers[model_id] = pipeline(
47
+ "text-classification",
48
+ model=model_id,
49
+ return_all_scores=True
50
+ )
51
+ return self._classifiers[model_id]
52
 
53
+ def predict(self, text, model_key=None):
54
+ """Make predictions using the specified or default model."""
55
+ model_id = self.models[model_key] if model_key in self.models else self.models[self.default_model]
56
+ classifier = self.get_classifier(model_id)
57
+
58
+ prediction = classifier(text)
59
  return self.parse_output(prediction)
60
 
61
+ # For standalone testing
62
+ if __name__ == "__main__":
63
+ # Create an instance of the SentimentAnalysisTool class
64
+ sentiment_analysis_tool = SentimentAnalysisTool()
65
+
66
+ # Test with a sample text
67
+ test_text = "I really enjoyed this product. It exceeded my expectations!"
68
+ result = sentiment_analysis_tool(test_text)
69
+ print(f"Input: {test_text}")
70
+ print(f"Result: {result}")