Chris4K committed on
Commit
b04682c
·
verified ·
1 Parent(s): 7d20aae

Update ner_tool.py

Browse files
Files changed (1) hide show
  1. ner_tool.py +241 -45
ner_tool.py CHANGED
@@ -1,57 +1,253 @@
1
- # Updated NamedEntityRecognitionTool in ner_tool.py
2
-
3
- from transformers import pipeline
4
- from transformers import Tool
5
 
class NamedEntityRecognitionTool(Tool):
    """Tool that labels named entities in a text using the default "ner" pipeline.

    Calling the tool runs token-level NER and folds WordPiece continuation
    tokens (those whose ``word`` starts with ``##``) back into the word they
    belong to, so each returned entry describes a whole word.
    """

    name = "ner_tool"
    description = "Identifies and labels various entities in a given text."
    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, text: str):
        """Run NER on *text* and return the word-level entities.

        Args:
            text: The text to analyze.

        Returns:
            A dict ``{"entities": [...]}`` where each item has the keys
            ``word``, ``label`` and ``entity_text``.
        """
        # Initialize the named entity recognition pipeline
        ner_analyzer = pipeline("ner")

        # Perform named entity recognition on the input text
        entities = ner_analyzer(text)

        # Completed word-level entries, in order of appearance
        word_entities = []

        # Entry currently being assembled and the offset where it ends
        pending = None
        pending_end = None

        for entity in entities:
            label = entity.get("entity", "UNKNOWN")
            word = entity.get("word", "")
            start = entity.get("start", -1)
            end = entity.get("end", -1)

            # Surface form taken from the input text
            entity_text = text[start:end].strip()

            is_continuation = word.startswith("##")
            if is_continuation and pending is not None and start == pending_end:
                # Sub-token of the word in progress: extend that word instead
                # of emitting a separate, headless fragment. The label of the
                # word's first token is kept.
                piece = entity_text or word[2:]
                pending["word"] += piece
                pending["entity_text"] += piece
                pending_end = end
                continue

            # A new word starts here: flush the previous one first
            if pending is not None:
                word_entities.append(pending)

            pending = {"word": word, "label": label, "entity_text": entity_text}
            pending_end = end

        # Flush the last word, if any
        if pending is not None:
            word_entities.append(pending)

        # Print the identified word-level entities
        print(f"Word-level Entities: {word_entities}")

        return {"entities": word_entities}  # Return a dictionary with the specified output component
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Any, Optional, Union
3
+ from smolagents import Tool
 
4
 
class NamedEntityRecognitionTool(Tool):
    """Agent tool that extracts named entities from text with a NER pipeline.

    The pipeline is created lazily on the first ``forward`` call and is
    re-created whenever a different model is requested. Results are returned
    as a formatted string in one of three layouts: ``simple``, ``grouped``
    (default) or ``detailed``.
    """

    name = "ner_tool"
    description = """
    Identifies and labels named entities in text using customizable NER models.
    Can recognize entities such as persons, organizations, locations, dates, etc.
    Returns a structured analysis of all entities found in the input text.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for named entities",
        },
        "model": {
            "type": "string",
            "description": "The NER model to use (default: 'dslim/bert-base-NER')",
            "nullable": True
        },
        "aggregation": {
            "type": "string",
            "description": "How to aggregate entities: 'simple' (just list), 'grouped' (by label), or 'detailed' (with confidence scores)",
            "nullable": True
        },
        "min_score": {
            "type": "number",
            "description": "Minimum confidence score threshold (0.0-1.0) for including entities",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Initialize the NER Tool with default settings."""
        super().__init__()
        self.default_model = "dslim/bert-base-NER"
        # NOTE(review): "flair/ner-english-ontonotes-large" looks like a Flair
        # checkpoint - confirm it can be loaded through transformers.pipeline.
        self.available_models = {
            "dslim/bert-base-NER": "Standard NER (English)",
            "jean-baptiste/camembert-ner": "French NER",
            "Davlan/bert-base-multilingual-cased-ner-hrl": "Multilingual NER",
            "Babelscape/wikineural-multilingual-ner": "WikiNeural Multilingual NER",
            "flair/ner-english-ontonotes-large": "OntoNotes English (fine-grained)",
            "elastic/distilbert-base-cased-finetuned-conll03-english": "CoNLL (fast)"
        }
        # Maps a label (without B-/I- prefix) to its display text
        self.entity_colors = {
            "PER": "🟥 Person",
            "PERSON": "🟥 Person",
            "LOC": "🟨 Location",
            "LOCATION": "🟨 Location",
            "GPE": "🟨 Location",
            "ORG": "🟦 Organization",
            "ORGANIZATION": "🟦 Organization",
            "MISC": "🟩 Miscellaneous",
            "DATE": "🟪 Date",
            "TIME": "🟪 Time",
            "MONEY": "💰 Money",
            "PERCENT": "📊 Percentage",
            "PRODUCT": "🛒 Product",
            "EVENT": "🎫 Event",
            "WORK_OF_ART": "🎨 Work of Art",
            "LAW": "⚖️ Law",
            "LANGUAGE": "🗣️ Language",
            "FAC": "🏢 Facility"
        }
        # Pipeline will be lazily loaded
        self._pipeline = None

    def _load_pipeline(self, model_name: str) -> bool:
        """Load the NER pipeline for *model_name*, falling back to the default model.

        Returns:
            True when a pipeline was loaded (requested or fallback model),
            False when nothing could be loaded.
        """
        try:
            from transformers import pipeline
            self._pipeline = pipeline("ner", model=model_name, aggregation_strategy="simple")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")

        if model_name == self.default_model:
            # The default model itself just failed; retrying it is pointless.
            return False

        try:
            # Fall back to default model
            from transformers import pipeline
            self._pipeline = pipeline("ner", model=self.default_model, aggregation_strategy="simple")
            return True
        except Exception as fallback_error:
            print(f"Error loading fallback model: {str(fallback_error)}")
            return False

    @staticmethod
    def _strip_bio_prefix(label: str) -> str:
        """Remove a leading ``B-``/``I-`` marker; the rest of the label is untouched."""
        if label.startswith(("B-", "I-")):
            return label[2:]
        return label

    @staticmethod
    def _entity_label(entity: Dict[str, Any]) -> str:
        """Return the raw label of a pipeline result.

        Aggregated pipelines (``aggregation_strategy`` other than "none")
        report the label under ``entity_group``; un-aggregated ones use
        ``entity``. Both are accepted.
        """
        return entity.get("entity_group") or entity.get("entity") or "UNKNOWN"

    def _get_friendly_label(self, label: str) -> str:
        """Convert technical entity labels to friendly descriptions with color indicators."""
        clean_label = self._strip_bio_prefix(label)
        return self.entity_colors.get(clean_label, f"🔷 {clean_label}")

    def _entity_line(self, entity: Dict[str, Any]) -> str:
        """Format one entity as a bullet line with its label and confidence."""
        word = entity.get("word", "")
        score = entity.get("score", 0)
        friendly_label = self._get_friendly_label(self._entity_label(entity))
        return f"• {word} - {friendly_label} (confidence: {score:.2f})"

    def forward(self, text: str, model: Optional[str] = None, aggregation: Optional[str] = None, min_score: Optional[float] = None) -> str:
        """
        Perform Named Entity Recognition on the input text.

        Args:
            text: The text to analyze
            model: NER model to use (default: dslim/bert-base-NER)
            aggregation: How to aggregate results (simple, grouped, detailed);
                matched case-insensitively, unknown values fall back to grouped
            min_score: Minimum confidence threshold (0.0-1.0), default 0.8

        Returns:
            Formatted string with NER analysis results, or a human-readable
            message when the model is unknown, cannot be loaded, analysis
            fails, or no entity passes the threshold.
        """
        # Set default values if parameters are None
        if model is None:
            model = self.default_model
        aggregation = "grouped" if aggregation is None else aggregation.strip().lower()
        if min_score is None:
            min_score = 0.8

        # Validate model choice
        if model not in self.available_models and not model.startswith("dslim/"):
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Nothing to analyze: skip the (expensive) model load altogether
        if not text or not text.strip():
            return "No entities were detected in the text with the current settings."

        # Load the model if not already loaded or if different from current
        loaded_name = getattr(getattr(self._pipeline, "model", None), "name_or_path", None)
        if self._pipeline is None or loaded_name != model:
            if not self._load_pipeline(model):
                return "Failed to load NER model. Please try a different model."

        # Perform NER analysis
        try:
            entities = self._pipeline(text)

            # Filter by confidence score
            entities = [e for e in entities if e.get('score', 0) >= min_score]

            if not entities:
                return "No entities were detected in the text with the current settings."

            # Process results based on aggregation method
            if aggregation == "simple":
                return self._format_simple(text, entities)
            if aggregation == "detailed":
                return self._format_detailed(text, entities)
            # default to grouped
            return self._format_grouped(text, entities)

        except Exception as e:
            return f"Error analyzing text: {str(e)}"

    def _format_simple(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities as a simple list, one bullet line per entity."""
        result = "Named Entities Found:\n\n"
        result += "".join(self._entity_line(entity) + "\n" for entity in entities)
        return result

    def _format_grouped(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities grouped by their category.

        Words are de-duplicated per category and listed in order of first
        appearance, so the output is stable between runs.
        """
        # Group entities by their label (insertion-ordered)
        grouped: Dict[str, List[str]] = {}
        for entity in entities:
            label = self._strip_bio_prefix(self._entity_label(entity))
            grouped.setdefault(label, []).append(entity.get("word", ""))

        # Build the result string
        result = "Named Entities by Category:\n\n"
        for label, words in grouped.items():
            friendly_label = self._get_friendly_label(label)
            unique_words = list(dict.fromkeys(words))
            result += f"{friendly_label}: {', '.join(unique_words)}\n"

        return result

    def _format_detailed(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities with detailed information including position in text.

        The input text is reproduced verbatim with every entity span wrapped
        as ``[span](LABEL)``, followed by one bullet line per entity.
        """
        # Collect the valid (start, end, label) spans, clipped to the text
        spans = []
        for entity in entities:
            start = entity.get("start")
            end = entity.get("end")
            if start is None or end is None:
                continue
            start = max(0, start)
            end = min(end, len(text))
            if start < end:
                spans.append((start, end, self._strip_bio_prefix(self._entity_label(entity))))
        spans.sort(key=lambda span: (span[0], span[1]))

        # Stitch plain text and marked spans together without altering spacing
        pieces = []
        cursor = 0
        for start, end, label in spans:
            if start < cursor:
                # Overlaps a span that was already emitted; keep the earlier one
                continue
            pieces.append(text[cursor:start])
            pieces.append(f"[{text[start:end]}]({label})")
            cursor = end
        pieces.append(text[cursor:])
        highlighted_text = "".join(pieces)

        # Get entity details
        entity_details = [self._entity_line(entity) for entity in entities]

        # Combine into final result
        result = "Entity Analysis:\n\n"
        result += "Text with Entities Marked:\n"
        result += highlighted_text + "\n\n"
        result += "Entity Details:\n"
        result += "\n".join(entity_details)

        return result

    def get_available_models(self) -> Dict[str, str]:
        """Return the dictionary of available models with descriptions."""
        return self.available_models
249
 
250
+ # Example usage:
251
+ # ner_tool = NamedEntityRecognitionTool()
252
+ # result = ner_tool("Apple Inc. is planning to open a new store in Paris, France next year.", model="dslim/bert-base-NER")
253
+ # print(result)