Chris4K committed on
Commit
3b1ba42
·
verified ·
1 Parent(s): 8ca215e

Rename simple_sentiment.py to summarizer_tool.py

Browse files
Files changed (2) hide show
  1. simple_sentiment.py +0 -98
  2. summarizer_tool.py +191 -0
simple_sentiment.py DELETED
@@ -1,98 +0,0 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
- from smolagents import Tool
4
-
5
- class SimpleSentimentTool(Tool):
6
- name = "sentiment_analysis"
7
- description = "This tool analyzes the sentiment of a given text."
8
-
9
- inputs = {
10
- "text": {
11
- "type": "string",
12
- "description": "The text to analyze for sentiment"
13
- },
14
- "model_key": {
15
- "type": "string",
16
- "description": "The model to use for sentiment analysis",
17
- "default": "oliverguhr/german-sentiment-bert",
18
- "nullable": True
19
- }
20
- }
21
- # Use a standard authorized type
22
- output_type = "string"
23
-
24
- # Available sentiment analysis models
25
- models = {
26
- "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
27
- "deberta": "microsoft/deberta-xlarge-mnli",
28
- "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
29
- "mobilebert": "lordtt13/emo-mobilebert",
30
- "reviews": "juliensimon/reviews-sentiment-analysis",
31
- "sbc": "sbcBI/sentiment_analysis_model",
32
- "german": "oliverguhr/german-sentiment-bert"
33
- }
34
-
35
- def __init__(self, default_model="distilbert", preload=False):
36
- """Initialize with a default model.
37
-
38
- Args:
39
- default_model: The default model to use if no model is specified
40
- preload: Whether to preload the default model at initialization
41
- """
42
- super().__init__()
43
- self.default_model = default_model
44
- self._classifiers = {}
45
-
46
- # Optionally preload the default model
47
- if preload:
48
- try:
49
- self._get_classifier(self.models[default_model])
50
- except Exception as e:
51
- print(f"Warning: Failed to preload model: {str(e)}")
52
-
53
- def _get_classifier(self, model_id):
54
- """Get or create a classifier for the given model ID."""
55
- if model_id not in self._classifiers:
56
- try:
57
- print(f"Loading model: {model_id}")
58
- self._classifiers[model_id] = pipeline(
59
- "text-classification",
60
- model=model_id,
61
- top_k=None # Return all scores
62
- )
63
- except Exception as e:
64
- print(f"Error loading model {model_id}: {str(e)}")
65
- # Fall back to distilbert if available
66
- if model_id != self.models["distilbert"]:
67
- print("Falling back to distilbert model...")
68
- return self._get_classifier(self.models["distilbert"])
69
- else:
70
- # Last resort - if even distilbert fails
71
- print("Critical error: Could not load default model")
72
- raise RuntimeError(f"Failed to load any sentiment model: {str(e)}")
73
- return self._classifiers[model_id]
74
-
75
- def forward(self, text: str, model_key="oliverguhr/german-sentiment-bert"):
76
- """Process input text and return sentiment predictions."""
77
- try:
78
- # Determine which model to use
79
- model_key = model_key or self.default_model
80
- model_id = self.models.get(model_key, self.models[self.default_model])
81
-
82
- # Get the classifier
83
- classifier = self._get_classifier(model_id)
84
-
85
- # Get predictions
86
- prediction = classifier(text)
87
-
88
- # Format as a dictionary
89
- result = {}
90
- for item in prediction[0]:
91
- result[item['label']] = float(item['score'])
92
-
93
- # Convert to JSON string for output
94
- import json
95
- return json.dumps(result, indent=2)
96
- except Exception as e:
97
- print(f"Error in sentiment analysis: {str(e)}")
98
- return json.dumps({"error": str(e)}, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
summarizer_tool.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from typing import Dict, Any, Optional
3
+ import warnings
4
+
5
+ # Suppress unnecessary warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
class TextSummarizerTool(Tool):
    name = "text_summarizer"
    description = """
    Summarizes text using various summarization methods and models.
    This tool can generate concise summaries of longer texts while preserving key information.
    It supports different summarization models and customizable parameters.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to be summarized",
        },
        "model": {
            "type": "string",
            "description": "Summarization model to use (default: 'facebook/bart-large-cnn')",
            "nullable": True
        },
        "max_length": {
            "type": "integer",
            "description": "Maximum length of the summary in tokens (default: 130)",
            "nullable": True
        },
        "min_length": {
            "type": "integer",
            "description": "Minimum length of the summary in tokens (default: 30)",
            "nullable": True
        },
        "style": {
            "type": "string",
            "description": "Style of summary: 'concise', 'detailed', or 'bullet_points' (default: 'concise')",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Initialize the Text Summarizer Tool with default settings."""
        super().__init__()
        self.default_model = "facebook/bart-large-cnn"
        # model id -> human-readable description; used for validation and
        # for the metadata line appended to each summary
        self.available_models = {
            "facebook/bart-large-cnn": "BART CNN (good for news)",
            "sshleifer/distilbart-cnn-12-6": "DistilBART (faster, smaller)",
            "google/pegasus-xsum": "Pegasus (extreme summarization)",
            "facebook/bart-large-xsum": "BART XSum (very concise)",
            "philschmid/bart-large-cnn-samsum": "BART SamSum (good for conversations)"
        }
        # Pipeline will be lazily loaded on first use
        self._pipeline = None

    def _load_pipeline(self, model_name: str) -> bool:
        """Load the summarization pipeline with the specified model.

        Falls back to ``self.default_model`` on failure. Returns True when a
        pipeline was loaded, False when both attempts failed.
        """
        try:
            from transformers import pipeline
            import torch

            # Use GPU (device 0) when available, otherwise CPU (-1)
            device = 0 if torch.cuda.is_available() else -1

            # Load the summarization pipeline
            self._pipeline = pipeline(
                "summarization",
                model=model_name,
                device=device
            )
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            try:
                # Fall back to the default model
                from transformers import pipeline
                import torch
                device = 0 if torch.cuda.is_available() else -1
                self._pipeline = pipeline(
                    "summarization",
                    model=self.default_model,
                    device=device
                )
                return True
            except Exception as fallback_error:
                print(f"Error loading fallback model: {str(fallback_error)}")
                return False

    def _format_as_bullets(self, summary: str) -> str:
        """Format a summary as bullet points, one sentence per bullet."""
        import re
        # Split the summary into sentences on terminal punctuation
        sentences = re.split(r'(?<=[.!?])\s+', summary)
        sentences = [s.strip() for s in sentences if s.strip()]

        # Format as bullet points
        bullet_points = []
        for sentence in sentences:
            # Skip very short sentences that might be artifacts
            if len(sentence) < 15:
                continue
            bullet_points.append(f"• {sentence}")

        # Bug fix: if every sentence was filtered out as too short, return
        # the plain summary rather than an empty string.
        if not bullet_points:
            return summary
        return "\n".join(bullet_points)

    def forward(self, text: str, model: Optional[str] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, style: Optional[str] = None) -> str:
        """
        Summarize the input text.

        Args:
            text: The text to summarize
            model: Summarization model to use (must be in ``available_models``)
            max_length: Maximum summary length in tokens
            min_length: Minimum summary length in tokens
            style: Style of summary ('concise', 'detailed', or 'bullet_points')

        Returns:
            Summarized text with a trailing metadata line, or an error message.
        """
        # Set default values if parameters are None
        if model is None:
            model = self.default_model
        if max_length is None:
            max_length = 130
        if min_length is None:
            min_length = 30
        if style is None:
            style = "concise"

        # Validate model choice
        if model not in self.available_models:
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Load the model if not already loaded or if different from current
        if self._pipeline is None or (hasattr(self._pipeline, 'model') and self._pipeline.model.name_or_path != model):
            if not self._load_pipeline(model):
                return "Failed to load summarization model. Please try a different model."

        # Adjust parameters based on style; 'concise' tightens the bounds,
        # 'detailed' loosens them, 'bullet_points' keeps them as given
        if style == "concise":
            max_length = min(100, max_length)
            min_length = min(30, min_length)
        elif style == "detailed":
            max_length = max(150, max_length)
            min_length = max(50, min_length)

        # Ensure text is not too short to be worth summarizing
        if len(text.split()) < 20:
            return "The input text is too short to summarize effectively."

        # Perform summarization
        try:
            # Truncate very long inputs if needed. NOTE(review): this counts
            # whitespace-separated words, not model tokens — the 1024 limit
            # is an approximation of typical model context sizes.
            max_input_length = 1024  # Most models have limits around 1024-2048 tokens
            words = text.split()
            if len(words) > max_input_length:
                text = " ".join(words[:max_input_length])
                note = "\n\nNote: The input was truncated due to length limits."
            else:
                note = ""

            summary = self._pipeline(
                text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False
            )

            result = summary[0]['summary_text']

            # Format the result based on style
            if style == "bullet_points":
                result = self._format_as_bullets(result)

            # Add metadata naming the model that was requested
            metadata = f"\n\nSummarized using: {self.available_models.get(model, model)}"

            return result + metadata + note

        except Exception as e:
            return f"Error summarizing text: {str(e)}"

    def get_available_models(self) -> Dict[str, str]:
        """Return the dictionary of available models with descriptions."""
        return self.available_models
+
188
+ # Example usage:
189
+ # summarizer = TextSummarizerTool()
190
+ # result = summarizer("Long text goes here...", model="facebook/bart-large-cnn", style="bullet_points")
191
+ # print(result)