Chris4K commited on
Commit
47f397e
·
verified ·
1 Parent(s): fd41f0a

Update sentiment_analysis.py

Browse files

Upgrade to smolagents.
Based on: https://huggingface.co/docs/smolagents/tutorials/tools

Files changed (1) hide show
  1. sentiment_analysis.py +16 -10
sentiment_analysis.py CHANGED
@@ -1,13 +1,18 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
- from transformers import Tool
4
 
5
  class SentimentAnalysisTool(Tool):
6
  name = "sentiment_analysis"
7
  description = "This tool analyses the sentiment of a given text."
8
 
9
- inputs = {"text": {"type": "text", "description": "The text to analyze"}}
10
- outputs = {"json": {"type": "json", "description": "Sentiment analysis results"}}
 
 
 
 
 
11
 
12
  # Available sentiment analysis models
13
  models = {
@@ -22,16 +27,17 @@ class SentimentAnalysisTool(Tool):
22
 
23
  def __init__(self, default_model="distilbert"):
24
  """Initialize with a default model."""
 
25
  self.default_model = default_model
26
  # Pre-load the default model to speed up first inference
27
  self._classifiers = {}
28
- self.get_classifier(self.models[default_model])
29
 
30
- def __call__(self, text: str):
31
  """Process input text and return sentiment predictions."""
32
  return self.predict(text)
33
 
34
- def parse_output(self, output_json):
35
  """Parse model output into a list of (label, score) tuples."""
36
  list_pred = []
37
  for i in range(len(output_json[0])):
@@ -40,23 +46,23 @@ class SentimentAnalysisTool(Tool):
40
  list_pred.append((label, score))
41
  return list_pred
42
 
43
- def get_classifier(self, model_id):
44
  """Get or create a classifier for the given model ID."""
45
  if model_id not in self._classifiers:
46
  self._classifiers[model_id] = pipeline(
47
  "text-classification",
48
  model=model_id,
49
- return_all_scores=True
50
  )
51
  return self._classifiers[model_id]
52
 
53
  def predict(self, text, model_key=None):
54
  """Make predictions using the specified or default model."""
55
  model_id = self.models[model_key] if model_key in self.models else self.models[self.default_model]
56
- classifier = self.get_classifier(model_id)
57
 
58
  prediction = classifier(text)
59
- return self.parse_output(prediction)
60
 
61
  # For standalone testing
62
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ from smolagents import Tool
4
 
5
  class SentimentAnalysisTool(Tool):
6
  name = "sentiment_analysis"
7
  description = "This tool analyses the sentiment of a given text."
8
 
9
+ inputs = {
10
+ "text": {
11
+ "type": "string",
12
+ "description": "The text to analyze for sentiment"
13
+ }
14
+ }
15
+ output_type = "list"
16
 
17
  # Available sentiment analysis models
18
  models = {
 
27
 
28
  def __init__(self, default_model="distilbert"):
29
  """Initialize with a default model."""
30
+ super().__init__()
31
  self.default_model = default_model
32
  # Pre-load the default model to speed up first inference
33
  self._classifiers = {}
34
+ self._get_classifier(self.models[default_model])
35
 
36
+ def forward(self, text: str):
37
  """Process input text and return sentiment predictions."""
38
  return self.predict(text)
39
 
40
+ def _parse_output(self, output_json):
41
  """Parse model output into a list of (label, score) tuples."""
42
  list_pred = []
43
  for i in range(len(output_json[0])):
 
46
  list_pred.append((label, score))
47
  return list_pred
48
 
49
+ def _get_classifier(self, model_id):
50
  """Get or create a classifier for the given model ID."""
51
  if model_id not in self._classifiers:
52
  self._classifiers[model_id] = pipeline(
53
  "text-classification",
54
  model=model_id,
55
+ top_k=None # This replaces return_all_scores=True
56
  )
57
  return self._classifiers[model_id]
58
 
59
  def predict(self, text, model_key=None):
60
  """Make predictions using the specified or default model."""
61
  model_id = self.models[model_key] if model_key in self.models else self.models[self.default_model]
62
+ classifier = self._get_classifier(model_id)
63
 
64
  prediction = classifier(text)
65
+ return self._parse_output(prediction)
66
 
67
  # For standalone testing
68
  if __name__ == "__main__":