Abhigyan committed on
Commit e3f321e · 1 Parent(s): b99eab6

Add app.py

Files changed (1)
  1. app.py +396 -0
app.py ADDED
# ner_module.py
import torch
import time
from typing import List, Dict, Any, Tuple
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class NERModel:
    """
    A singleton class to manage NER model loading and prediction.
    Ensures the potentially large model is loaded only once.
    """
    _instance = None
    _model = None
    _tokenizer = None
    _pipeline = None
    _model_name = None  # Store the model name used for initialization

    @classmethod
    def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
        """
        Singleton pattern: get the existing instance or create a new one.
        The model_name argument is honored only on first initialization.
        """
        if cls._instance is None:
            logger.info(f"Creating new NERModel instance with model: {model_name}")
            cls._instance = cls(model_name)
        elif cls._model_name != model_name:
            logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
        return cls._instance

    def __init__(self, model_name: str):
        """
        Initialize the model, tokenizer, and pipeline.
        Private constructor -- use get_instance() instead.
        """
        if NERModel._instance is not None:
            raise Exception("This class is a singleton! Use get_instance() to get the object.")
        self.model_name = model_name
        NERModel._model_name = model_name  # Store the model name
        self._load_model()
        NERModel._instance = self  # Assign the instance here

    def _load_model(self):
        """Load the NER model and tokenizer from Hugging Face."""
        logger.info(f"Loading model: {self.model_name}")
        start_time = time.time()

        try:
            # Load tokenizer and model
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)

            # Set PyTorch models to evaluation mode (important for inference)
            if isinstance(self._model, torch.nn.Module):
                self._model.eval()
                # self._model.share_memory()  # Only needed with explicit multiprocessing

            # Create the NER pipeline.
            # device=-1 selects CPU, device=0 the first GPU, etc.; by default
            # the pipeline picks a device automatically.
            self._pipeline = pipeline(
                "ner",
                model=self._model,
                tokenizer=self._tokenizer,
                # grouped_entities=True  # Group subword tokens automatically (alternative to manual combination)
            )

            load_time = time.time() - start_time
            logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")

        except Exception as e:
            logger.error(f"Error loading model {self.model_name}: {e}")
            # Clean up partial loads and re-raise to signal failure
            self._tokenizer = None
            self._model = None
            self._pipeline = None
            raise

    def predict(self, text: str) -> List[Dict[str, Any]]:
        """
        Run NER prediction on the input text using the loaded pipeline.

        Args:
            text: The input string to perform NER on.

        Returns:
            A list of dictionaries, one per entity token identified by the
            pipeline. The exact format depends on the pipeline configuration
            (e.g., grouped_entities).
        """
        if self._pipeline is None:
            logger.error("NER pipeline is not initialized. Cannot predict.")
            return []  # Return an empty list rather than raising

        if not text or not isinstance(text, str):
            logger.warning("Prediction called with empty or invalid text.")
            return []

        logger.debug(f"Running prediction on text: '{text[:100]}...'")  # Log a snippet
        try:
            # The pipeline handles tokenization and prediction
            results = self._pipeline(text)
            logger.debug(f"Prediction results: {results}")
            return results
        except Exception as e:
            logger.error(f"Error during NER prediction: {e}")
            return []  # Return an empty list on error

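# --- Illustrative usage (a sketch added for clarity; not part of the original commit).
# With grouped_entities left commented out above, the raw pipeline output is one dict
# per (sub)word token -- the format TextProcessor.combine_entities() below expects.
# Scores and token indices here are made up for illustration:
#
#   model = NERModel.get_instance()
#   raw = model.predict("Angela Merkel lives in Berlin.")
#   # raw ~= [{'entity': 'B-PER', 'score': 0.99, 'index': 1, 'word': 'Angela', 'start': 0, 'end': 6},
#   #         {'entity': 'I-PER', 'score': 0.99, 'index': 2, 'word': 'Merkel', 'start': 7, 'end': 13},
#   #         {'entity': 'B-LOC', 'score': 0.99, 'index': 5, 'word': 'Berlin', 'start': 23, 'end': 29}]
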
class TextProcessor:
    """
    Provides static methods for processing text for NER tasks, including
    combining subword entities and handling large texts via chunking.
    """

    @staticmethod
    def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
        """
        Combine entities that may be split across subword tokens (B-TAG, I-TAG).
        This method assumes the pipeline did *not* use grouped_entities=True.

        Args:
            ner_results: The raw output from the NER pipeline (a list of token dictionaries).
            original_text: The original input text, used to extract entity words accurately.

        Returns:
            A list of dictionaries, each representing a combined entity with
            'entity_type', 'start', 'end', 'score', and 'word'.
        """
        if not ner_results:
            return []

        combined_entities = []
        current_entity = None

        for token in ner_results:
            # Basic validation of token structure
            if not all(k in token for k in ['entity', 'start', 'end', 'score']):
                logger.warning(f"Skipping malformed token: {token}")
                continue

            # Skip 'O' tags (outside any entity)
            if token['entity'] == 'O':
                # If we were tracking an entity, finalize it before moving on
                if current_entity:
                    combined_entities.append(current_entity)
                    current_entity = None
                continue

            # Extract the entity type (e.g., 'PER', 'LOC') by removing the 'B-' or 'I-' prefix
            entity_tag = token['entity']
            if entity_tag.startswith('B-') or entity_tag.startswith('I-'):
                entity_type = entity_tag[2:]
            else:
                # Handle tags without a B-/I- prefix (less common)
                logger.warning(f"Unexpected entity tag format: {entity_tag}. Using as is.")
                entity_type = entity_tag

            # Start of a new entity ('B-'), or an 'I-' that doesn't continue the current entity
            if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)):
                # Finalize the previous entity if it exists
                if current_entity:
                    combined_entities.append(current_entity)

                # Start the new entity
                current_entity = {
                    'entity_type': entity_type,
                    'start': token['start'],
                    'end': token['end'],
                    'score': float(token['score']),
                    'token_count': 1  # Track the token count for score averaging
                }

            # Continuation of the current entity ('I-' with a matching type)
            elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type:
                # Extend the end position
                current_entity['end'] = token['end']
                # Update the running average of the score
                current_entity['score'] = (current_entity['score'] * current_entity['token_count'] + float(token['score'])) / (current_entity['token_count'] + 1)
                current_entity['token_count'] += 1

            # Handle unexpected cases (e.g., an 'I-' tag without a preceding 'B-' or matching 'I-')
            else:
                logger.warning(f"Encountered unexpected token sequence at: {token}. Resetting current entity.")
                if current_entity:
                    combined_entities.append(current_entity)
                current_entity = None  # Reset

        # Add the last tracked entity if it exists
        if current_entity:
            combined_entities.append(current_entity)

        # Extract the actual text 'word' for each combined entity
        for entity in combined_entities:
            try:
                entity['word'] = original_text[entity['start']:entity['end']].strip()
            except IndexError:
                logger.error(f"Index error extracting word for entity: {entity} with text length {len(original_text)}")
                entity['word'] = "[Error extracting word]"
            # Remove the internal helper key (even if word extraction failed)
            entity.pop('token_count', None)

        # Sort entities by start position
        combined_entities.sort(key=lambda x: x['start'])

        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
        return combined_entities

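    # --- Worked example (a sketch added for clarity; not part of the original commit).
    # Feeding the raw token output sketched above through combine_entities merges the
    # B-PER/I-PER pair into a single span and recovers the surface form from the text:
    #
    #   text = "Angela Merkel lives in Berlin."
    #   TextProcessor.combine_entities(raw, text)
    #   # ~= [{'entity_type': 'PER', 'start': 0, 'end': 13, 'score': 0.99, 'word': 'Angela Merkel'},
    #   #     {'entity_type': 'LOC', 'start': 23, 'end': 29, 'score': 0.99, 'word': 'Berlin'}]
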
    @staticmethod
    def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
        """
        Process a large text by splitting it into overlapping chunks, running NER
        on each chunk, and then combining the results.

        Args:
            text: The large input text string.
            model: The initialized NERModel instance.
            chunk_size: The maximum size of each chunk. Note that chunking here is
                done in characters, while the tokenizer's limit (often 512 for
                BERT-style models) is in tokens, so the cap below is approximate.
            overlap: The number of characters to overlap between consecutive chunks
                so that entities spanning chunk boundaries are still captured.

        Returns:
            A list of combined entity dictionaries for the entire text.
        """
        if not text:
            return []

        # Cap chunk_size at the tokenizer's maximum length if one is available
        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
            tokenizer_max_len = model._tokenizer.model_max_length
            if chunk_size > tokenizer_max_len:
                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
                chunk_size = tokenizer_max_len
        # Ensure the overlap is reasonable compared to the chunk size
        if overlap >= chunk_size // 2:
            logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
            overlap = chunk_size // 4

        logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
        chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
        logger.info(f"Split text into {len(chunks)} chunks.")

        all_raw_results = []
        total_processing_time = 0

        for i, (chunk_text, start_pos) in enumerate(chunks):
            logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
            start_time = time.time()

            # Get raw predictions for the current chunk
            raw_results_chunk = model.predict(chunk_text)

            chunk_processing_time = time.time() - start_time
            total_processing_time += chunk_processing_time
            logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")

            # Adjust entity positions so they refer to the original text
            for result in raw_results_chunk:
                # Check that 'start' and 'end' exist before adjusting
                if 'start' in result and 'end' in result:
                    result['start'] += start_pos
                    result['end'] += start_pos
                else:
                    logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")

            all_raw_results.extend(raw_results_chunk)

        logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
        logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")

        # Combine entities from all chunks. Note that combine_entities does not
        # handle chunk overlaps specially: an entity that falls entirely inside
        # the overlap region of two chunks will appear twice in the raw results.
        # A more robust approach would merge based on position and confidence;
        # here we simply deduplicate exact (start, end, type) matches below.
        combined_entities = TextProcessor.combine_entities(all_raw_results, text)

        # Simple deduplication based on exact start, end, and entity type
        unique_entities = []
        seen_entities = set()
        for entity in combined_entities:
            entity_key = (entity['start'], entity['end'], entity['entity_type'])
            if entity_key not in seen_entities:
                unique_entities.append(entity)
                seen_entities.add(entity_key)
            else:
                logger.debug(f"Duplicate entity removed: {entity}")

        logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
        return unique_entities

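    # --- Worked example of the offset arithmetic above (added for illustration;
    # not part of the original commit). With chunk_size=512 and overlap=50,
    # chunk 0 covers roughly chars [0, 512) and chunk 1 starts at 512 - 50 = 462.
    # An entity the pipeline reports at (start=10, end=16) within chunk 1 therefore
    # maps back to (472, 478) in the original text, which is exactly what the
    # start_pos adjustment in the loop computes.
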
    @staticmethod
    def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
        """
        Split text into overlapping chunks, trying to respect word boundaries.

        Args:
            text: The input text string.
            chunk_size: The target maximum size of each chunk.
            overlap: The desired overlap between consecutive chunks.

        Returns:
            A list of (chunk_text, start_position_in_original_text) tuples.
        """
        if not text:
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if chunk_size <= overlap:
            raise ValueError("chunk_size must be greater than overlap")

        chunks = []
        start = 0
        text_len = len(text)

        while start < text_len:
            # Determine the ideal end position
            end = start + chunk_size

            # If the ideal end is beyond the text length, just take the rest
            if end >= text_len:
                chunks.append((text[start:], start))
                break  # We've reached the end

            # Try to find a suitable split point (whitespace) near the ideal end,
            # searching backwards within a window of `overlap` characters
            split_pos = -1
            search_start = max(start, end - overlap)  # Don't search too far back
            for i in range(end, search_start - 1, -1):
                # Prefer splitting at whitespace
                if text[i].isspace():
                    split_pos = i + 1  # Split *after* the space
                    break
                # Splitting at punctuation could serve as a fallback here (optional)

            # If no good split point was found nearby, just cut at chunk_size
            if split_pos == -1 or split_pos <= start:
                actual_end = end
                logger.debug(f"No suitable whitespace found near char {end}, cutting at {actual_end}")
            else:
                actual_end = split_pos
                logger.debug(f"Found whitespace split point at char {actual_end}")

            # Ensure the chunk isn't empty if split_pos was too close to start
            if actual_end <= start:
                actual_end = end  # Fall back to a hard cut if the split logic fails

            # Add the chunk and its starting position
            chunks.append((text[start:actual_end], start))

            # Advance by chunk_size minus overlap. Starting the next chunk at
            # actual_end - overlap would also work, but it makes the overlap size
            # variable, so we keep the simpler fixed stride.
            next_start = start + (chunk_size - overlap)

            # Ensure we always make progress
            if next_start <= start:
                logger.warning("Chunking logic resulted in no progress. Moving start by 1.")
                next_start = start + 1

            start = next_start

        return chunks
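

# --- Minimal self-test (a hedged sketch added for illustration; not part of the
# original commit). Running this module directly downloads the default model, so
# it assumes network access plus the torch/transformers dependencies above.
if __name__ == "__main__":
    sample = (
        "Angela Merkel visited Paris last year. "
        "She met officials from the United Nations. "
    ) * 40  # Repeat to force process_large_text into multiple chunks

    ner_model = NERModel.get_instance()
    found = TextProcessor.process_large_text(sample, ner_model, chunk_size=400, overlap=50)
    for ent in found[:10]:
        print(f"{ent['entity_type']:>4}  {ent['word']!r}  [{ent['start']}:{ent['end']}]  score={ent['score']:.2f}")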