Abhigyan committed on
Commit e3f321e
Parent(s): b99eab6
Add app.py
app.py
ADDED
@@ -0,0 +1,396 @@
# ner_module.py
import logging
import time
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class NERModel:
    """
    A singleton class that manages NER model loading and prediction,
    ensuring the potentially large model is loaded only once.
    """
    _instance = None
    _model = None
    _tokenizer = None
    _pipeline = None
    _model_name = None  # Model name used for the first initialization

    @classmethod
    def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
        """
        Singleton accessor: return the existing instance or create a new one.
        The model_name argument is honored only on the first call.
        """
        if cls._instance is None:
            logger.info(f"Creating new NERModel instance with model: {model_name}")
            cls._instance = cls(model_name)
        elif cls._model_name != model_name:
            logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
        return cls._instance

    def __init__(self, model_name: str):
        """
        Initialize the model, tokenizer, and pipeline.
        Private constructor: use get_instance() instead.
        """
        if NERModel._instance is not None:
            raise Exception("This class is a singleton! Use get_instance() to get the object.")
        self.model_name = model_name
        NERModel._model_name = model_name  # Record the model name
        self._load_model()
        NERModel._instance = self  # Register this instance

    def _load_model(self):
        """Load the NER model and tokenizer from Hugging Face."""
        logger.info(f"Loading model: {self.model_name}")
        start_time = time.time()

        try:
            # Load tokenizer and model
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)

            # Put PyTorch models in evaluation mode (important for inference)
            if isinstance(self._model, torch.nn.Module):
                self._model.eval()
                # self._model.share_memory()  # only needed with explicit multiprocessing

            # Create the NER pipeline. Specify device=-1 for CPU, device=0 for
            # the first GPU, etc.; by default the pipeline picks a device itself.
            self._pipeline = pipeline(
                "ner",
                model=self._model,
                tokenizer=self._tokenizer,
                # grouped_entities=True  # would merge subword tokens automatically,
                # an alternative to the manual combination in TextProcessor below
            )

            load_time = time.time() - start_time
            logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")

        except Exception as e:
            logger.error(f"Error loading model {self.model_name}: {e}")
            # Clean up partial loads, then re-raise to signal failure
            self._tokenizer = None
            self._model = None
            self._pipeline = None
            raise

    def predict(self, text: str) -> List[Dict[str, Any]]:
        """
        Run NER prediction on the input text using the loaded pipeline.

        Args:
            text: The input string to perform NER on.

        Returns:
            A list of dictionaries, one per entity found by the pipeline.
            The exact format depends on the pipeline configuration
            (e.g., grouped_entities).
        """
        if self._pipeline is None:
            logger.error("NER pipeline is not initialized. Cannot predict.")
            return []  # Alternatively, raise an error

        if not text or not isinstance(text, str):
            logger.warning("Prediction called with empty or invalid text.")
            return []

        logger.debug(f"Running prediction on text: '{text[:100]}...'")  # Log a snippet
        try:
            # The pipeline handles tokenization and prediction
            results = self._pipeline(text)
            logger.debug(f"Prediction results: {results}")
            return results
        except Exception as e:
            logger.error(f"Error during NER prediction: {e}")
            return []  # Return an empty list on error


class TextProcessor:
    """
    Static helpers for processing text for NER tasks: combining subword
    entities and handling large texts via chunking.
    """

    @staticmethod
    def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
        """
        Combine entities that were split into subword tokens (B-TAG, I-TAG).
        Assumes the pipeline did *not* use grouped_entities=True.

        Args:
            ner_results: Raw NER pipeline output (a list of token dictionaries).
            original_text: The original input text, used to extract entity words.

        Returns:
            A list of dictionaries, each representing a combined entity with
            'entity_type', 'start', 'end', 'score', and 'word'.
        """
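        # Illustrative example (hypothetical scores): for the text
        # "Barack Obama", raw tokens such as
        #   {'entity': 'B-PER', 'start': 0, 'end': 6, 'score': 0.99}
        #   {'entity': 'I-PER', 'start': 7, 'end': 12, 'score': 0.97}
        # are merged into a single entity with the averaged score:
        #   {'entity_type': 'PER', 'start': 0, 'end': 12, 'score': 0.98, 'word': 'Barack Obama'}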
        if not ner_results:
            return []

        combined_entities = []
        current_entity = None

        for token in ner_results:
            # Basic validation of token structure
            if not all(k in token for k in ['entity', 'start', 'end', 'score']):
                logger.warning(f"Skipping malformed token: {token}")
                continue

            # Skip 'O' tags (outside any entity)
            if token['entity'] == 'O':
                # If we were tracking an entity, finalize it before moving on
                if current_entity:
                    combined_entities.append(current_entity)
                    current_entity = None
                continue

            # Extract the entity type (e.g., 'PER', 'LOC') by removing 'B-' or 'I-'
            entity_tag = token['entity']
            if entity_tag.startswith('B-') or entity_tag.startswith('I-'):
                entity_type = entity_tag[2:]
            else:
                # Tags without a B-/I- prefix are uncommon; use them as-is
                logger.warning(f"Unexpected entity tag format: {entity_tag}. Using as is.")
                entity_type = entity_tag

            # Start of a new entity: a 'B-' tag, or an 'I-' tag that does not
            # continue the entity currently being tracked
            if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)):
                # Finalize the previous entity if it exists
                if current_entity:
                    combined_entities.append(current_entity)

                # Start the new entity
                current_entity = {
                    'entity_type': entity_type,
                    'start': token['start'],
                    'end': token['end'],
                    'score': float(token['score']),
                    'token_count': 1  # Track token count for score averaging
                }

            # Continuation of the current entity ('I-' with a matching type)
            elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type:
                # Extend the end position
                current_entity['end'] = token['end']
                # Update the running average of the score
                current_entity['score'] = (current_entity['score'] * current_entity['token_count'] + float(token['score'])) / (current_entity['token_count'] + 1)
                current_entity['token_count'] += 1

            # Any other tag shape (e.g., a prefix-less tag): finalize and reset
            else:
                logger.warning(f"Encountered unexpected token sequence at: {token}. Resetting current entity.")
                if current_entity:
                    combined_entities.append(current_entity)
                current_entity = None  # Reset

        # Add the last tracked entity if it exists
        if current_entity:
            combined_entities.append(current_entity)

        # Extract the actual text 'word' for each combined entity.
        # Note: string slicing never raises IndexError, so check bounds explicitly.
        for entity in combined_entities:
            if entity['end'] <= len(original_text):
                entity['word'] = original_text[entity['start']:entity['end']].strip()
            else:
                logger.error(f"Entity span {entity} exceeds text length {len(original_text)}")
                entity['word'] = "[Error extracting word]"
            # Remove the internal helper key
            entity.pop('token_count', None)

        # Sort entities by start position
        combined_entities.sort(key=lambda x: x['start'])

        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
        return combined_entities

    @staticmethod
    def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
        """
        Process large text by splitting it into overlapping chunks, running NER
        on each chunk, and combining the results.

        Args:
            text: The large input text string.
            model: The initialized NERModel instance.
            chunk_size: Maximum size of each chunk, in characters. Note that
                BERT-style models cap input at ~512 *tokens*; a character
                budget of the same size is a conservative fit, since a string
                of N characters tokenizes to at most roughly N tokens.
            overlap: Characters to overlap between consecutive chunks so that
                entities spanning chunk boundaries are still captured.

        Returns:
            A list of combined entity dictionaries for the entire text.
        """
        if not text:
            return []

        # Cap chunk_size at the tokenizer's maximum length if that is smaller
        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
            tokenizer_max_len = model._tokenizer.model_max_length
            if chunk_size > tokenizer_max_len:
                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
                chunk_size = tokenizer_max_len
        # Keep the overlap small relative to the chunk size
        if overlap >= chunk_size // 2:
            logger.warning(f"Overlap {overlap} is large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
            overlap = chunk_size // 4
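        # E.g., with the default chunk_size of 512, any overlap of 256 or more
        # is reduced to 512 // 4 = 128.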

        logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
        chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
        logger.info(f"Split text into {len(chunks)} chunks.")

        all_raw_results = []
        total_processing_time = 0

        for i, (chunk_text, start_pos) in enumerate(chunks):
            logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
            start_time = time.time()

            # Get raw predictions for the current chunk
            raw_results_chunk = model.predict(chunk_text)

            chunk_processing_time = time.time() - start_time
            total_processing_time += chunk_processing_time
            logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")

            # Shift entity positions so they refer to the original text
            for result in raw_results_chunk:
                if 'start' in result and 'end' in result:
                    result['start'] += start_pos
                    result['end'] += start_pos
                else:
                    logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")

            all_raw_results.extend(raw_results_chunk)

        logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
        logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")

        # Combine entities from all chunks. Because chunks overlap, an entity
        # that falls entirely inside an overlap region may appear twice, so
        # exact duplicates are removed below. A more robust approach would
        # merge overlapping spans based on confidence.
        combined_entities = TextProcessor.combine_entities(all_raw_results, text)

        # Simple deduplication on exact (start, end, type)
        unique_entities = []
        seen_entities = set()
        for entity in combined_entities:
            entity_key = (entity['start'], entity['end'], entity['entity_type'])
            if entity_key not in seen_entities:
                unique_entities.append(entity)
                seen_entities.add(entity_key)
            else:
                logger.debug(f"Duplicate entity removed: {entity}")

        logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
        return unique_entities

    @staticmethod
    def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
        """
        Split text into overlapping chunks, preferring splits at word boundaries.

        Args:
            text: The input text string.
            chunk_size: The target maximum size of each chunk.
            overlap: The desired overlap between consecutive chunks.

        Returns:
            A list of (chunk_text, start_position_in_original_text) tuples.
        """
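        # Illustrative arithmetic: each chunk starts chunk_size - overlap
        # characters after the previous one, so chunk_size=10, overlap=3 on a
        # 25-character text yields chunks starting at 0, 7, 14, and 21
        # (whitespace-aware splitting may shorten individual chunks).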
        if not text:
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if chunk_size <= overlap:
            raise ValueError("chunk_size must be greater than overlap")

        chunks = []
        start = 0
        text_len = len(text)

        while start < text_len:
            # The ideal end position
            end = start + chunk_size

            # If the ideal end is beyond the text, take the rest and stop
            if end >= text_len:
                chunks.append((text[start:], start))
                break

            # Search backwards from the ideal end for a whitespace split point,
            # within a window of `overlap` characters
            split_pos = -1
            search_start = max(start, end - overlap)
            for i in range(end, search_start - 1, -1):
                if text[i].isspace():
                    split_pos = i + 1  # Split *after* the space
                    break
                # Splitting at punctuation would be a possible fallback here

            # If no whitespace was found nearby, cut hard at the chunk boundary
            if split_pos == -1 or split_pos <= start:
                actual_end = end
                logger.debug(f"No suitable whitespace found near char {end}, cutting at {actual_end}")
            else:
                actual_end = split_pos
                logger.debug(f"Found whitespace split point at char {actual_end}")

            # Guard against an empty chunk if the split point fell at `start`
            if actual_end <= start:
                actual_end = end  # Fall back to a hard cut

            # Record the chunk and its starting position
            chunks.append((text[start:actual_end], start))

            # Advance by chunk_size minus overlap for a consistent overlap size;
            # stepping from actual_end instead would make the overlap variable
            next_start = start + (chunk_size - overlap)

            # Ensure we always make progress
            if next_start <= start:
                logger.warning("Chunking logic resulted in no progress. Moving start by 1.")
                next_start = start + 1

            start = next_start

        return chunks
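

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API): exercises
# the singleton accessor, entity combination, and chunked processing. The
# sample text is made up; the entities reported depend entirely on the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load (or reuse) the shared model instance
    ner_model = NERModel.get_instance()

    sample_text = (
        "Barack Obama visited Paris last spring and met officials from "
        "the European Commission."
    )

    # For short inputs, predict + combine_entities is enough
    raw = ner_model.predict(sample_text)
    entities = TextProcessor.combine_entities(raw, sample_text)
    for ent in entities:
        print(f"{ent['entity_type']:>5}  {ent['score']:.3f}  {ent['word']}")

    # For long documents, process_large_text handles chunking and deduplication
    long_entities = TextProcessor.process_large_text(sample_text * 50, ner_model, chunk_size=512, overlap=50)
    print(f"Found {len(long_entities)} entities in the long text.")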