File size: 16,471 Bytes
3bdd5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68c4f8
3bdd5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68c4f8
3bdd5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68c4f8
3bdd5ce
 
f68c4f8
 
 
 
 
 
 
 
3bdd5ce
 
 
 
 
 
 
 
f68c4f8
 
 
 
3bdd5ce
 
 
f68c4f8
 
3bdd5ce
 
f68c4f8
3bdd5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
f68c4f8
 
3bdd5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68c4f8
3bdd5ce
 
f68c4f8
 
 
3bdd5ce
 
f68c4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bdd5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68c4f8
 
 
3bdd5ce
 
f68c4f8
3bdd5ce
f68c4f8
3bdd5ce
f68c4f8
 
 
 
 
 
 
3bdd5ce
f68c4f8
3bdd5ce
f68c4f8
 
 
3bdd5ce
f68c4f8
 
 
3bdd5ce
 
f68c4f8
3bdd5ce
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# ner_module.py
import torch
import time
from typing import List, Dict, Any, Tuple
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import logging

# Module-level logging setup: INFO level with a timestamped format.
# logging.basicConfig only takes effect when the root logger has no handlers
# yet, so an application that configured logging first keeps its own setup.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger used by every class in this module.
logger = logging.getLogger(__name__)

class NERModel:
    """
    Singleton owning one Hugging Face token-classification model.

    The tokenizer, model and ``pipeline("ner", ...)`` are built once, on the
    first call to :meth:`get_instance`; later calls return the same object.

    Class attributes:
        _instance: The single NERModel object, or None until a load succeeds.
        _model / _tokenizer / _pipeline: Class-level defaults (None). The loaded
            objects are assigned on the instance by ``_load_model``.
        _model_name: Name of the model the singleton was successfully built with.
    """
    _instance = None
    _model = None
    _tokenizer = None
    _pipeline = None
    _model_name = None  # Set only after a successful load

    @classmethod
    def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl") -> "NERModel":
        """
        Return the shared instance, creating and loading it on first use.

        Args:
            model_name: Hugging Face model id. It is only used by the call that
                creates the instance. A different name on a later call is logged
                and ignored.

        Returns:
            The singleton NERModel.

        Raises:
            Whatever ``_load_model`` re-raises when the first load fails. The
            singleton then stays unset, so a later call retries the load.
        """
        if cls._instance is None:
            logger.info(f"Creating new NERModel instance with model: {model_name}")
            cls._instance = cls(model_name)
        elif cls._model_name != model_name:
            logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
        return cls._instance

    def __init__(self, model_name: str):
        """
        Load ``model_name`` and register this object as the singleton.

        Private constructor: use :meth:`get_instance` instead.

        Raises:
            RuntimeError: If a singleton instance already exists.
        """
        if NERModel._instance is not None:
            raise RuntimeError("This class is a singleton! Use get_instance() to get the object.")
        self.model_name = model_name
        self._load_model()
        # Publish the class-level state only after loading succeeded, so a
        # failed load does not leave _model_name naming a model that never loaded.
        NERModel._model_name = model_name
        NERModel._instance = self

    def _load_model(self):
        """
        Load the tokenizer and model for ``self.model_name`` and build the pipeline.

        On failure the error is logged with its traceback, the three attributes
        are reset to None and the original exception is re-raised.
        """
        logger.info(f"Loading model: {self.model_name}")
        start_time = time.time()

        try:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)

            # Put a PyTorch model into evaluation mode for inference.
            if isinstance(self._model, torch.nn.Module):
                self._model.eval()

            # Ungrouped, token-level output: TextProcessor.combine_entities merges
            # B-/I- tokens itself and expects results without pipeline grouping.
            self._pipeline = pipeline(
                "ner",
                model=self._model,
                tokenizer=self._tokenizer,
            )
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error(f"...{e}") lost.
            logger.exception(f"Error loading model {self.model_name}: {e}")
            self._tokenizer = None
            self._model = None
            self._pipeline = None
            raise

        load_time = time.time() - start_time
        logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")

    def predict(self, text: str) -> List[Dict[str, Any]]:
        """
        Run the NER pipeline on ``text``.

        Args:
            text: The input string to perform NER on.

        Returns:
            The list of prediction dicts produced by the pipeline. An empty list
            is returned (and the reason logged) when the pipeline is not
            initialized, when ``text`` is empty or not a string, or when the
            pipeline raises.
        """
        if self._pipeline is None:
            logger.error("NER pipeline is not initialized. Cannot predict.")
            return []

        if not text or not isinstance(text, str):
            logger.warning("Prediction called with empty or invalid text.")
            return []

        # Lazy %-arguments: the messages are only formatted when DEBUG is
        # enabled. `results` may be long, so eager f-strings wasted work.
        logger.debug("Running prediction on text: '%s...'", text[:100])
        try:
            results = self._pipeline(text)
        except Exception as e:
            # Best-effort by design: callers get [] rather than an exception,
            # but the traceback is preserved in the log.
            logger.exception("Error during NER prediction: %s", e)
            return []
        logger.debug("Prediction results: %s", results)
        return results


class TextProcessor:
    """
    Static helpers for NER post-processing.

    Covers merging token-level B-/I- predictions into entity spans and
    running NER over long texts by splitting them into overlapping chunks.
    """

    @staticmethod
    def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
        """
        Combine entities that might be split into subword tokens (B-TAG, I-TAG).

        This method assumes the pipeline did *not* group entities, i.e. each
        item is one token with 'entity', 'start', 'end' and 'score'.

        Args:
            ner_results: Raw token dictionaries, in text order.
            original_text: The text the 'start'/'end' offsets index into. It is
                used to extract each entity's 'word'.

        Returns:
            A list of dicts sorted by 'start', each with 'entity_type', 'start',
            'end', 'score' (mean of the token scores) and 'word'.
        """
        if not ner_results:
            return []

        combined_entities = TextProcessor._merge_bio_tokens(ner_results, original_text)
        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
        return combined_entities

    @staticmethod
    def _merge_bio_tokens(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
        """Merge B-/I- token predictions into entity spans (no summary logging)."""
        combined_entities: List[Dict[str, Any]] = []
        current_entity = None

        for token in ner_results:
            # Every field read below must be present and non-None; a token
            # without offsets cannot be placed in the text.
            if any(token.get(k) is None for k in ('entity', 'start', 'end', 'score')):
                logger.warning(f"Skipping malformed token: {token}")
                continue

            entity_tag = token['entity']

            # 'O' (outside any entity) closes the entity being tracked.
            if entity_tag == 'O':
                if current_entity is not None:
                    combined_entities.append(current_entity)
                    current_entity = None
                continue

            has_bio_prefix = entity_tag.startswith(('B-', 'I-'))
            entity_type = entity_tag[2:] if has_bio_prefix else entity_tag

            continues_current = (
                entity_tag.startswith('I-')
                and current_entity is not None
                and current_entity['entity_type'] == entity_type
                # Continuations must move forward through the text. A token that
                # starts before the current span ends (e.g. from the next,
                # overlapping chunk) begins a new entity instead of producing
                # an inverted span.
                and token['start'] >= current_entity['end']
            )

            if continues_current:
                count = current_entity['token_count']
                current_entity['end'] = token['end']
                # Running mean of token scores.
                current_entity['score'] = (current_entity['score'] * count + float(token['score'])) / (count + 1)
                current_entity['token_count'] = count + 1
                continue

            if not has_bio_prefix:
                logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.")

            # B- tag, I- tag of a different type or out of order, or an unprefixed tag:
            # close the previous entity and start a new one.
            if current_entity is not None:
                combined_entities.append(current_entity)
            current_entity = {
                'entity_type': entity_type,
                'start': token['start'],
                'end': token['end'],
                'score': float(token['score']),
                'token_count': 1,  # Number of tokens merged, for score averaging
            }

        if current_entity is not None:
            combined_entities.append(current_entity)

        # Attach the surface text for each entity.
        for entity in combined_entities:
            entity.pop('token_count', None)
            try:
                text_len = len(original_text)
                start = max(0, min(entity['start'], text_len))
                end = max(start, min(entity['end'], text_len))
                entity['word'] = original_text[start:end].strip()
            except Exception as e:
                logger.error(f"Error extracting word for entity: {entity}, error: {e}")
                entity['word'] = "[Error extracting word]"

        combined_entities.sort(key=lambda x: x['start'])
        return combined_entities

    @staticmethod
    def process_large_text(text: str, model: "NERModel", chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
        """
        Run NER over a long text using overlapping chunks and merge the results.

        Args:
            text: The large input text string.
            model: The initialized NERModel instance. Its ``predict`` and
                ``_tokenizer`` are used.
            chunk_size: The maximum size of each chunk, in characters.
            overlap: The number of characters shared by consecutive chunks.

        Returns:
            Entity dicts ('entity_type', 'start', 'end', 'score', 'word') for
            the whole text, with offsets into ``text``, deduplicated across
            chunk overlaps and sorted by 'start'.
        """
        if not text:
            return []

        # NOTE(review): chunk_size counts characters, while model_max_length
        # counts tokens (including special tokens). Capping one by the other is
        # a heuristic. Text that tokenizes into more tokens than characters could
        # still exceed the model limit; confirm for the target languages.
        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
            tokenizer_max_len = model._tokenizer.model_max_length
            if chunk_size > tokenizer_max_len:
                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
                chunk_size = tokenizer_max_len
            if overlap >= chunk_size // 2:
                logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
                overlap = chunk_size // 4

        logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
        chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
        logger.info(f"Split text into {len(chunks)} chunks.")

        combined_entities: List[Dict[str, Any]] = []
        raw_token_count = 0
        total_processing_time = 0.0

        for i, (chunk_text, start_pos) in enumerate(chunks):
            logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
            start_time = time.time()

            raw_results_chunk = model.predict(chunk_text)

            chunk_processing_time = time.time() - start_time
            total_processing_time += chunk_processing_time
            logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")

            # Shift chunk-relative offsets to offsets into the full text.
            for result in raw_results_chunk:
                if 'start' in result and 'end' in result:
                    result['start'] += start_pos
                    result['end'] += start_pos
                else:
                    logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")

            raw_token_count += len(raw_results_chunk)
            # Merge per chunk. Because chunks overlap, the concatenated raw
            # tokens of consecutive chunks are not in text order. Merging them
            # together could let an I- token of one chunk extend an entity
            # from the previous chunk.
            combined_entities.extend(TextProcessor._merge_bio_tokens(raw_results_chunk, text))

        logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
        logger.info(f"Total raw entities found across all chunks: {raw_token_count}")
        logger.info(f"Combined {raw_token_count} raw tokens into {len(combined_entities)} entities.")

        unique_entities = TextProcessor._deduplicate_entities(combined_entities)

        logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
        return unique_entities

    @staticmethod
    def _deduplicate_entities(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Drop duplicate detections from overlapping chunks.

        Two entities conflict when they have the same 'entity_type' and their
        overlap is more than 50% of the shorter one's length. A new entity
        replaces all kept entities it conflicts with only if its score is
        strictly higher than each of theirs; otherwise it is dropped.

        Returns:
            The surviving entities, sorted by 'start'.
        """
        unique: List[Dict[str, Any]] = []
        # Sorted input keeps the output sorted, since each survivor is appended
        # with the largest 'start' seen so far.
        for entity in sorted(entities, key=lambda e: e['start']):
            entity_length = entity['end'] - entity['start']
            conflicts = []
            for existing in unique:
                if existing['entity_type'] != entity['entity_type']:
                    continue
                overlap_length = min(entity['end'], existing['end']) - max(entity['start'], existing['start'])
                if overlap_length <= 0:
                    continue
                shorter_length = min(entity_length, existing['end'] - existing['start'])
                if overlap_length > 0.5 * shorter_length:
                    conflicts.append(existing)

            if not conflicts:
                unique.append(entity)
            elif all(entity['score'] > c['score'] for c in conflicts):
                # Remove by identity: distinct entities may compare equal.
                conflict_ids = {id(c) for c in conflicts}
                unique = [e for e in unique if id(e) not in conflict_ids]
                unique.append(entity)
        return unique

    @staticmethod
    def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
        """
        Split text into overlapping chunks, preferring to cut between words.

        Each chunk is at most ``chunk_size`` characters. A chunk is shortened
        back to just after a whitespace character, looking back at most
        ``overlap`` characters, when a hard cut would split a word. The next
        chunk starts ``overlap`` characters before the previous chunk's end.
        If that lands mid-word, the start moves forward to the next word
        boundary inside the overlap. Consecutive chunks never leave a gap.

        Args:
            text: The input text string.
            chunk_size: Maximum chunk length in characters (must be > 0).
            overlap: Desired overlap between consecutive chunks (0 <= overlap < chunk_size).

        Returns:
            A list of (chunk_text, start_position_in_original_text) tuples.

        Raises:
            ValueError: If chunk_size <= 0, overlap < 0, or chunk_size <= overlap.
        """
        if not text:
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap < 0:
            raise ValueError("overlap must be non-negative")
        if chunk_size <= overlap:
            raise ValueError("chunk_size must be greater than overlap")

        chunks: List[Tuple[str, int]] = []
        start = 0
        text_len = len(text)

        while start < text_len:
            end = start + chunk_size

            # The remainder fits: emit it and stop.
            if end >= text_len:
                chunks.append((text[start:], start))
                break

            actual_end = end
            # When text[end] is whitespace, cutting at `end` already falls
            # between words. Otherwise look back for whitespace. Never look back
            # below start + overlap + 1, so the next start still advances.
            if not text[end].isspace():
                lowest = max(end - overlap, start + overlap + 1)
                for i in range(end - 1, lowest - 1, -1):
                    if text[i].isspace():
                        actual_end = i + 1  # Cut just after the whitespace
                        break

            chunks.append((text[start:actual_end], start))

            next_start = actual_end - overlap
            # Avoid starting the next chunk mid-word (that yields fragments the
            # model may tag). Advance to the next word start, but only while it
            # stays strictly inside the overlap, so some overlap remains.
            if not text[next_start - 1].isspace():
                boundary = next_start
                while boundary < actual_end and not text[boundary - 1].isspace():
                    boundary += 1
                if boundary < actual_end:
                    next_start = boundary

            # Defensive: the bounds above already guarantee progress.
            if next_start <= start:
                next_start = start + 1
            start = next_start

        return chunks