Chris4K commited on
Commit
d4fdfec
·
verified ·
1 Parent(s): a49896f

Update sentiment_analysis.py

Browse files
Files changed (1) hide show
  1. sentiment_analysis.py +61 -42
sentiment_analysis.py CHANGED
@@ -2,17 +2,23 @@ import gradio as gr
2
  from transformers import pipeline
3
  from smolagents import Tool
4
 
5
- class SentimentAnalysisTool(Tool):
6
  name = "sentiment_analysis"
7
- description = "This tool analyses the sentiment of a given text."
8
 
9
  inputs = {
10
  "text": {
11
  "type": "string",
12
  "description": "The text to analyze for sentiment"
 
 
 
 
 
13
  }
14
  }
15
- output_type = "json"
 
16
 
17
  # Available sentiment analysis models
18
  models = {
@@ -25,52 +31,65 @@ class SentimentAnalysisTool(Tool):
25
  "german": "oliverguhr/german-sentiment-bert"
26
  }
27
 
28
- def __init__(self, default_model="distilbert"):
29
- """Initialize with a default model."""
 
 
 
 
 
30
  super().__init__()
31
  self.default_model = default_model
32
- # Pre-load the default model to speed up first inference
33
  self._classifiers = {}
34
- self._get_classifier(self.models[default_model])
35
-
36
- def forward(self, text: str):
37
- """Process input text and return sentiment predictions."""
38
- return self.predict(text)
39
 
40
- def _parse_output(self, output_json):
41
- """Parse model output into a dictionary of scores by label."""
42
- result = {}
43
- for i in range(len(output_json[0])):
44
- label = output_json[0][i]['label']
45
- score = output_json[0][i]['score']
46
- result[label] = score
47
- return result
48
 
49
  def _get_classifier(self, model_id):
50
  """Get or create a classifier for the given model ID."""
51
  if model_id not in self._classifiers:
52
- self._classifiers[model_id] = pipeline(
53
- "text-classification",
54
- model=model_id,
55
- top_k=None # This replaces return_all_scores=True
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
57
  return self._classifiers[model_id]
58
 
59
- def predict(self, text, model_key=None):
60
- """Make predictions using the specified or default model."""
61
- model_id = self.models[model_key] if model_key in self.models else self.models[self.default_model]
62
- classifier = self._get_classifier(model_id)
63
-
64
- prediction = classifier(text)
65
- return self._parse_output(prediction)
66
-
67
- # For standalone testing
68
- if __name__ == "__main__":
69
- # Create an instance of the SentimentAnalysisTool class
70
- sentiment_analysis_tool = SentimentAnalysisTool()
71
-
72
- # Test with a sample text
73
- test_text = "I really enjoyed this product. It exceeded my expectations!"
74
- result = sentiment_analysis_tool(test_text)
75
- print(f"Input: {test_text}")
76
- print(f"Result: {result}")
 
 
 
 
 
2
  from transformers import pipeline
3
  from smolagents import Tool
4
 
5
+ class SimpleSentimentTool(Tool):
6
  name = "sentiment_analysis"
7
+ description = "This tool analyzes the sentiment of a given text."
8
 
9
  inputs = {
10
  "text": {
11
  "type": "string",
12
  "description": "The text to analyze for sentiment"
13
+ },
14
+ "model_key": {
15
+ "type": "string",
16
+ "description": "The model to use for sentiment analysis",
17
+ "default": None
18
  }
19
  }
20
+ # Use a standard authorized type
21
+ output_type = "dict[str, float]"
22
 
23
  # Available sentiment analysis models
24
  models = {
 
31
  "german": "oliverguhr/german-sentiment-bert"
32
  }
33
 
34
+ def __init__(self, default_model="distilbert", preload=False):
35
+ """Initialize with a default model.
36
+
37
+ Args:
38
+ default_model: The default model to use if no model is specified
39
+ preload: Whether to preload the default model at initialization
40
+ """
41
  super().__init__()
42
  self.default_model = default_model
 
43
  self._classifiers = {}
 
 
 
 
 
44
 
45
+ # Optionally preload the default model
46
+ if preload:
47
+ try:
48
+ self._get_classifier(self.models[default_model])
49
+ except Exception as e:
50
+ print(f"Warning: Failed to preload model: {str(e)}")
 
 
51
 
52
  def _get_classifier(self, model_id):
53
  """Get or create a classifier for the given model ID."""
54
  if model_id not in self._classifiers:
55
+ try:
56
+ print(f"Loading model: {model_id}")
57
+ self._classifiers[model_id] = pipeline(
58
+ "text-classification",
59
+ model=model_id,
60
+ top_k=None # Return all scores
61
+ )
62
+ except Exception as e:
63
+ print(f"Error loading model {model_id}: {str(e)}")
64
+ # Fall back to distilbert if available
65
+ if model_id != self.models["distilbert"]:
66
+ print("Falling back to distilbert model...")
67
+ return self._get_classifier(self.models["distilbert"])
68
+ else:
69
+ # Last resort - if even distilbert fails
70
+ print("Critical error: Could not load default model")
71
+ raise RuntimeError(f"Failed to load any sentiment model: {str(e)}")
72
  return self._classifiers[model_id]
73
 
74
+ def forward(self, text: str, model_key=None):
75
+ """Process input text and return sentiment predictions."""
76
+ try:
77
+ # Determine which model to use
78
+ model_key = model_key or self.default_model
79
+ model_id = self.models.get(model_key, self.models[self.default_model])
80
+
81
+ # Get the classifier
82
+ classifier = self._get_classifier(model_id)
83
+
84
+ # Get predictions
85
+ prediction = classifier(text)
86
+
87
+ # Format as a dictionary
88
+ result = {}
89
+ for item in prediction[0]:
90
+ result[item['label']] = float(item['score'])
91
+
92
+ return result
93
+ except Exception as e:
94
+ print(f"Error in sentiment analysis: {str(e)}")
95
+ return {"error": str(e)}