Chris4K committed on
Commit
7d26c0c
·
verified ·
1 Parent(s): 2113210

Update ner_tool.py

Browse files
Files changed (1) hide show
  1. ner_tool.py +168 -7
ner_tool.py CHANGED
@@ -62,7 +62,10 @@ class NamedEntityRecognitionTool(Tool):
62
  "WORK_OF_ART": "🎨 Work of Art",
63
  "LAW": "⚖️ Law",
64
  "LANGUAGE": "🗣️ Language",
65
- "FAC": "🏢 Facility"
 
 
 
66
  }
67
  # Pipeline will be lazily loaded
68
  self._pipeline = None
@@ -71,14 +74,41 @@ class NamedEntityRecognitionTool(Tool):
71
  """Load the NER pipeline with the specified model."""
72
  try:
73
  from transformers import pipeline
74
- self._pipeline = pipeline("ner", model=model_name, aggregation_strategy="simple")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  return True
76
  except Exception as e:
77
  print(f"Error loading model {model_name}: {str(e)}")
78
  try:
79
  # Fall back to default model
80
  from transformers import pipeline
81
- self._pipeline = pipeline("ner", model=self.default_model, aggregation_strategy="simple")
 
 
 
 
 
 
 
82
  return True
83
  except Exception as fallback_error:
84
  print(f"Error loading fallback model: {str(fallback_error)}")
@@ -88,6 +118,34 @@ class NamedEntityRecognitionTool(Tool):
88
  """Convert technical entity labels to friendly descriptions with color indicators."""
89
  # Strip B- or I- prefixes that indicate beginning or inside of entity
90
  clean_label = label.replace("B-", "").replace("I-", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  return self.entity_colors.get(clean_label, f"🔷 {clean_label}")
92
 
93
  def forward(self, text: str, model: str = None, aggregation: str = None, min_score: float = None) -> str:
@@ -127,6 +185,16 @@ class NamedEntityRecognitionTool(Tool):
127
  # Filter by confidence score
128
  entities = [e for e in entities if e.get('score', 0) >= min_score]
129
 
 
 
 
 
 
 
 
 
 
 
130
  if not entities:
131
  return "No entities were detected in the text with the current settings."
132
 
@@ -143,9 +211,40 @@ class NamedEntityRecognitionTool(Tool):
143
 
144
  def _format_simple(self, text: str, entities: List[Dict[str, Any]]) -> str:
145
  """Format entities as a simple list."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  result = "Named Entities Found:\n\n"
147
 
148
- for entity in entities:
149
  word = entity.get("word", "")
150
  label = entity.get("entity", "UNKNOWN")
151
  score = entity.get("score", 0)
@@ -157,10 +256,41 @@ class NamedEntityRecognitionTool(Tool):
157
 
158
  def _format_grouped(self, text: str, entities: List[Dict[str, Any]]) -> str:
159
  """Format entities grouped by their category."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  # Group entities by their label
161
  grouped = {}
162
 
163
- for entity in entities:
164
  word = entity.get("word", "")
165
  label = entity.get("entity", "UNKNOWN").replace("B-", "").replace("I-", "")
166
 
@@ -181,11 +311,42 @@ class NamedEntityRecognitionTool(Tool):
181
 
182
  def _format_detailed(self, text: str, entities: List[Dict[str, Any]]) -> str:
183
  """Format entities with detailed information including position in text."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  # First, build an entity map to highlight the entire text
185
  character_labels = [None] * len(text)
186
 
187
  # Mark each character with its entity
188
- for entity in entities:
189
  start = entity.get("start", 0)
190
  end = entity.get("end", 0)
191
  label = entity.get("entity", "UNKNOWN")
@@ -226,7 +387,7 @@ class NamedEntityRecognitionTool(Tool):
226
 
227
  # Get entity details
228
  entity_details = []
229
- for entity in entities:
230
  word = entity.get("word", "")
231
  label = entity.get("entity", "UNKNOWN")
232
  score = entity.get("score", 0)
 
62
  "WORK_OF_ART": "🎨 Work of Art",
63
  "LAW": "⚖️ Law",
64
  "LANGUAGE": "🗣️ Language",
65
+ "FAC": "🏢 Facility",
66
+ # Fix for models that don't properly tag entities
67
+ "O": "Not an entity",
68
+ "UNKNOWN": "🔷 Entity"
69
  }
70
  # Pipeline will be lazily loaded
71
  self._pipeline = None
 
74
  """Load the NER pipeline with the specified model."""
75
  try:
76
  from transformers import pipeline
77
+ import torch
78
+
79
+ # Try to detect if GPU is available
80
+ device = 0 if torch.cuda.is_available() else -1
81
+
82
+ # For some models, we need special handling
83
+ if "dslim/bert-base-NER" in model_name:
84
+ # This model works better with a specific aggregation strategy
85
+ self._pipeline = pipeline(
86
+ "ner",
87
+ model=model_name,
88
+ aggregation_strategy="first",
89
+ device=device
90
+ )
91
+ else:
92
+ self._pipeline = pipeline(
93
+ "ner",
94
+ model=model_name,
95
+ aggregation_strategy="simple",
96
+ device=device
97
+ )
98
  return True
99
  except Exception as e:
100
  print(f"Error loading model {model_name}: {str(e)}")
101
  try:
102
  # Fall back to default model
103
  from transformers import pipeline
104
+ import torch
105
+ device = 0 if torch.cuda.is_available() else -1
106
+ self._pipeline = pipeline(
107
+ "ner",
108
+ model=self.default_model,
109
+ aggregation_strategy="first",
110
+ device=device
111
+ )
112
  return True
113
  except Exception as fallback_error:
114
  print(f"Error loading fallback model: {str(fallback_error)}")
 
118
  """Convert technical entity labels to friendly descriptions with color indicators."""
119
  # Strip B- or I- prefixes that indicate beginning or inside of entity
120
  clean_label = label.replace("B-", "").replace("I-", "")
121
+
122
+ # Handle common name and location patterns with heuristics
123
+ if clean_label == "UNKNOWN" or clean_label == "O":
124
+ # Apply some basic heuristics to detect entity types
125
+ # This is a fallback when the model fails to properly tag
126
+ text = self._current_entity_text.lower() if hasattr(self, '_current_entity_text') else ""
127
+
128
+ # Check for capitalized words which might be names or places
129
+ if text and text[0].isupper():
130
+ # Countries and major cities
131
+ countries_and_cities = ["germany", "france", "spain", "italy", "london",
132
+ "paris", "berlin", "rome", "new york", "tokyo",
133
+ "beijing", "moscow", "canada", "australia", "india",
134
+ "china", "japan", "russia", "brazil", "mexico"]
135
+
136
+ if text.lower() in countries_and_cities:
137
+ return self.entity_colors.get("LOC", "🟨 Location")
138
+
139
+ # Common first names (add more as needed)
140
+ common_names = ["john", "mike", "sarah", "david", "michael", "james",
141
+ "robert", "mary", "jennifer", "linda", "michael", "william",
142
+ "kristof", "chris", "thomas", "daniel", "matthew", "joseph",
143
+ "donald", "richard", "charles", "paul", "mark", "kevin"]
144
+
145
+ name_parts = text.lower().split()
146
+ if name_parts and name_parts[0] in common_names:
147
+ return self.entity_colors.get("PER", "🟥 Person")
148
+
149
  return self.entity_colors.get(clean_label, f"🔷 {clean_label}")
150
 
151
  def forward(self, text: str, model: str = None, aggregation: str = None, min_score: float = None) -> str:
 
185
  # Filter by confidence score
186
  entities = [e for e in entities if e.get('score', 0) >= min_score]
187
 
188
+ # Store the text for better heuristics
189
+ for entity in entities:
190
+ word = entity.get("word", "")
191
+ start = entity.get("start", 0)
192
+ end = entity.get("end", 0)
193
+ # Store the actual text from the input for better entity type detection
194
+ entity['actual_text'] = text[start:end]
195
+ # Set this for _get_friendly_label to use
196
+ self._current_entity_text = text[start:end]
197
+
198
  if not entities:
199
  return "No entities were detected in the text with the current settings."
200
 
 
211
 
212
  def _format_simple(self, text: str, entities: List[Dict[str, Any]]) -> str:
213
  """Format entities as a simple list."""
214
+ # Process word pieces and handle subtoken merging
215
+ merged_entities = []
216
+ current_entity = None
217
+
218
+ for entity in sorted(entities, key=lambda e: e.get("start", 0)):
219
+ word = entity.get("word", "")
220
+ start = entity.get("start", 0)
221
+ end = entity.get("end", 0)
222
+ label = entity.get("entity", "UNKNOWN")
223
+ score = entity.get("score", 0)
224
+
225
+ # Check if this is a continuation (subtoken)
226
+ if word.startswith("##"):
227
+ if current_entity:
228
+ # Extend the current entity
229
+ current_entity["word"] += word.replace("##", "")
230
+ current_entity["end"] = end
231
+ # Keep the average score
232
+ current_entity["score"] = (current_entity["score"] + score) / 2
233
+ continue
234
+
235
+ # Start a new entity
236
+ current_entity = {
237
+ "word": word,
238
+ "start": start,
239
+ "end": end,
240
+ "entity": label,
241
+ "score": score
242
+ }
243
+ merged_entities.append(current_entity)
244
+
245
  result = "Named Entities Found:\n\n"
246
 
247
+ for entity in merged_entities:
248
  word = entity.get("word", "")
249
  label = entity.get("entity", "UNKNOWN")
250
  score = entity.get("score", 0)
 
256
 
257
  def _format_grouped(self, text: str, entities: List[Dict[str, Any]]) -> str:
258
  """Format entities grouped by their category."""
259
+ # Process word pieces and handle subtoken merging
260
+ merged_entities = []
261
+ current_entity = None
262
+
263
+ for entity in sorted(entities, key=lambda e: e.get("start", 0)):
264
+ word = entity.get("word", "")
265
+ start = entity.get("start", 0)
266
+ end = entity.get("end", 0)
267
+ label = entity.get("entity", "UNKNOWN")
268
+ score = entity.get("score", 0)
269
+
270
+ # Check if this is a continuation (subtoken)
271
+ if word.startswith("##"):
272
+ if current_entity:
273
+ # Extend the current entity
274
+ current_entity["word"] += word.replace("##", "")
275
+ current_entity["end"] = end
276
+ # Keep the average score
277
+ current_entity["score"] = (current_entity["score"] + score) / 2
278
+ continue
279
+
280
+ # Start a new entity
281
+ current_entity = {
282
+ "word": word,
283
+ "start": start,
284
+ "end": end,
285
+ "entity": label,
286
+ "score": score
287
+ }
288
+ merged_entities.append(current_entity)
289
+
290
  # Group entities by their label
291
  grouped = {}
292
 
293
+ for entity in merged_entities:
294
  word = entity.get("word", "")
295
  label = entity.get("entity", "UNKNOWN").replace("B-", "").replace("I-", "")
296
 
 
311
 
312
  def _format_detailed(self, text: str, entities: List[Dict[str, Any]]) -> str:
313
  """Format entities with detailed information including position in text."""
314
+ # Process word pieces and handle subtoken merging
315
+ merged_entities = []
316
+ current_entity = None
317
+
318
+ for entity in sorted(entities, key=lambda e: e.get("start", 0)):
319
+ word = entity.get("word", "")
320
+ start = entity.get("start", 0)
321
+ end = entity.get("end", 0)
322
+ label = entity.get("entity", "UNKNOWN")
323
+ score = entity.get("score", 0)
324
+
325
+ # Check if this is a continuation (subtoken)
326
+ if word.startswith("##"):
327
+ if current_entity:
328
+ # Extend the current entity
329
+ current_entity["word"] += word.replace("##", "")
330
+ current_entity["end"] = end
331
+ # Keep the average score
332
+ current_entity["score"] = (current_entity["score"] + score) / 2
333
+ continue
334
+
335
+ # Start a new entity
336
+ current_entity = {
337
+ "word": word,
338
+ "start": start,
339
+ "end": end,
340
+ "entity": label,
341
+ "score": score
342
+ }
343
+ merged_entities.append(current_entity)
344
+
345
  # First, build an entity map to highlight the entire text
346
  character_labels = [None] * len(text)
347
 
348
  # Mark each character with its entity
349
+ for entity in merged_entities:
350
  start = entity.get("start", 0)
351
  end = entity.get("end", 0)
352
  label = entity.get("entity", "UNKNOWN")
 
387
 
388
  # Get entity details
389
  entity_details = []
390
+ for entity in merged_entities:
391
  word = entity.get("word", "")
392
  label = entity.get("entity", "UNKNOWN")
393
  score = entity.get("score", 0)