Abhigyan committed
Commit f68c4f8 · 1 Parent(s): 3bdd5ce
Files changed (3)
  1. __pycache__/ner_module.cpython-310.pyc +0 -0
  2. app.py +184 -381
  3. ner_module.py +68 -86
__pycache__/ner_module.cpython-310.pyc ADDED
Binary file (9.93 kB).
 
app.py CHANGED
@@ -1,396 +1,199 @@
- # ner_module.py
- import torch
- import time
- from typing import List, Dict, Any, Tuple
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
- import logging
-
- # Configure logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- class NERModel:
-     """
-     A singleton class to manage the NER model loading and prediction.
-     Ensures the potentially large model is loaded only once.
-     """
-     _instance = None
-     _model = None
-     _tokenizer = None
-     _pipeline = None
-     _model_name = None  # Store model name used for initialization
-
-     @classmethod
-     def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
-         """
-         Singleton pattern: Get the existing instance or create a new one.
-         Uses the specified model_name only during the first initialization.
-         """
-         if cls._instance is None:
-             logger.info(f"Creating new NERModel instance with model: {model_name}")
-             cls._instance = cls(model_name)
-         elif cls._model_name != model_name:
-             logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
-         return cls._instance
-
-     def __init__(self, model_name: str):
-         """
-         Initialize the model, tokenizer, and pipeline.
-         Private constructor - use get_instance() instead.
-         """
-         if NERModel._instance is not None:
-             raise Exception("This class is a singleton! Use get_instance() to get the object.")
-         else:
-             self.model_name = model_name
-             NERModel._model_name = model_name  # Store the model name
-             self._load_model()
-             NERModel._instance = self  # Assign the instance here
-
-     def _load_model(self):
-         """Load the NER model and tokenizer from Hugging Face."""
-         logger.info(f"Loading model: {self.model_name}")
-         start_time = time.time()
-
-         try:
-             # Load tokenizer and model
-             self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-             self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)
-
-             # Check if the model is a PyTorch model for potential optimizations
-             if isinstance(self._model, torch.nn.Module):
-                 self._model.eval()  # Set model to evaluation mode (important for inference)
-                 # self._model.share_memory()  # share_memory() might not be needed unless using multiprocessing explicitly
-
-             # Create the NER pipeline
-             # Specify device=-1 for CPU, device=0 for first GPU, etc.
-             # Let pipeline decide device automatically by default, or specify if needed
-             self._pipeline = pipeline(
-                 "ner",
-                 model=self._model,
-                 tokenizer=self._tokenizer,
-                 # grouped_entities=True  # Group subword tokens automatically (alternative to manual combination)
-             )
-
-             load_time = time.time() - start_time
-             logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")
-
-         except Exception as e:
-             logger.error(f"Error loading model {self.model_name}: {e}")
-             # Clean up partial loads if necessary
-             self._tokenizer = None
-             self._model = None
-             self._pipeline = None
-             # Re-raise the exception to signal failure
-             raise
-
-     def predict(self, text: str) -> List[Dict[str, Any]]:
-         """
-         Run NER prediction on the input text using the loaded pipeline.
-
-         Args:
-             text: The input string to perform NER on.
-
-         Returns:
-             A list of dictionaries, where each dictionary represents an entity
-             identified by the pipeline. The exact format depends on the pipeline
-             configuration (e.g., grouped_entities).
-         """
-         if self._pipeline is None:
-             logger.error("NER pipeline is not initialized. Cannot predict.")
-             return []  # Return empty list or raise an error
-
-         if not text or not isinstance(text, str):
-             logger.warning("Prediction called with empty or invalid text.")
-             return []
-
-         logger.debug(f"Running prediction on text: '{text[:100]}...'")  # Log snippet
-         try:
-             # The pipeline handles tokenization and prediction
-             results = self._pipeline(text)
-             logger.debug(f"Prediction results: {results}")
-             return results
-         except Exception as e:
-             logger.error(f"Error during NER prediction: {e}")
-             return []  # Return empty list on error
-
-
- class TextProcessor:
-     """
-     Provides static methods for processing text, specifically for NER tasks,
-     including combining subword entities and handling large texts via chunking.
-     """
-
-     @staticmethod
-     def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
-         """
-         Combine entities that might be split into subword tokens (B-TAG, I-TAG).
-         This method assumes the pipeline did *not* use grouped_entities=True.
-
-         Args:
-             ner_results: The raw output from the NER pipeline (list of token dictionaries).
-             original_text: The original text input to extract entity words accurately.
-
-         Returns:
-             A list of dictionaries, each representing a combined entity with
-             'entity_type', 'start', 'end', 'score', and 'word'.
-         """
-         if not ner_results:
-             return []
-
-         combined_entities = []
-         current_entity = None
-
-         for token in ner_results:
-             # Basic validation of token structure
-             if not all(k in token for k in ['entity', 'start', 'end', 'score']):
-                 logger.warning(f"Skipping malformed token: {token}")
-                 continue
-
-             # Skip 'O' tags (Outside any entity)
-             if token['entity'] == 'O':
-                 # If we were tracking an entity, finalize it before moving on
-                 if current_entity:
-                     combined_entities.append(current_entity)
-                     current_entity = None
-                 continue
-
-             # Extract entity type (e.g., 'PER', 'LOC') removing 'B-' or 'I-'
-             entity_tag = token['entity']
-             if entity_tag.startswith('B-') or entity_tag.startswith('I-'):
-                 entity_type = entity_tag[2:]
-             else:
-                 # Handle cases where the tag might not have B-/I- prefix (less common)
-                 logger.warning(f"Unexpected entity tag format: {entity_tag}. Using as is.")
-                 entity_type = entity_tag
-
-             # Start of a new entity ('B-') or continuation of a different entity type
-             if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)):
-                 # Finalize the previous entity if it exists
-                 if current_entity:
-                     combined_entities.append(current_entity)
-
-                 # Start the new entity
-                 current_entity = {
-                     'entity_type': entity_type,
-                     'start': token['start'],
-                     'end': token['end'],
-                     'score': float(token['score']),
-                     'token_count': 1  # Keep track of tokens for averaging score
-                 }
-
-             # Continuation of the current entity ('I-' and matching type)
-             elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type:
-                 # Extend the end position
-                 current_entity['end'] = token['end']
-                 # Update the score (e.g., average)
-                 current_entity['score'] = (current_entity['score'] * current_entity['token_count'] + float(token['score'])) / (current_entity['token_count'] + 1)
-                 current_entity['token_count'] += 1
-
-             # Handle unexpected cases (e.g., I- tag without preceding B- or matching I-)
-             else:
-                 logger.warning(f"Encountered unexpected token sequence at: {token}. Resetting current entity.")
-                 if current_entity:
-                     combined_entities.append(current_entity)
-                 current_entity = None  # Reset
-
-
-         # Add the last tracked entity if it exists
-         if current_entity:
-             combined_entities.append(current_entity)
-
-         # Extract the actual text 'word' for each combined entity
-         for entity in combined_entities:
-             try:
-                 entity['word'] = original_text[entity['start']:entity['end']].strip()
-                 # Remove internal helper key
-                 if 'token_count' in entity:
-                     del entity['token_count']
-             except IndexError:
-                 logger.error(f"Index error extracting word for entity: {entity} with text length {len(original_text)}")
-                 entity['word'] = "[Error extracting word]"
-
-
-         # Optional: Sort entities by start position
-         combined_entities.sort(key=lambda x: x['start'])
-
-         logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
-         return combined_entities
-
-     @staticmethod
-     def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
-         """
-         Process large text by splitting it into overlapping chunks, running NER
-         on each chunk, and then combining the results intelligently.
-
-         Args:
-             text: The large input text string.
-             model: The initialized NERModel instance.
-             chunk_size: The maximum size of each text chunk (in characters or tokens,
-                         depending on the tokenizer's limits, often ~512 for BERT).
-             overlap: The number of characters/tokens to overlap between consecutive chunks
-                      to ensure entities spanning chunk boundaries are captured.
-
-         Returns:
-             A list of combined entity dictionaries for the entire text.
-         """
-         if not text:
-             return []
-
-         # Use tokenizer max length if available and smaller than chunk_size
-         if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
-             tokenizer_max_len = model._tokenizer.model_max_length
-             if chunk_size > tokenizer_max_len:
-                 logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
-                 chunk_size = tokenizer_max_len
-         # Ensure overlap is reasonable compared to chunk size
-         if overlap >= chunk_size // 2:
-             logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
-             overlap = chunk_size // 4
-
-
-         logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
-         chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
-         logger.info(f"Split text into {len(chunks)} chunks.")
-
-         all_raw_results = []
-         total_processing_time = 0
-
-         for i, (chunk_text, start_pos) in enumerate(chunks):
-             logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
-             start_time = time.time()
-
-             # Get raw predictions for the current chunk
-             raw_results_chunk = model.predict(chunk_text)
-
-             chunk_processing_time = time.time() - start_time
-             total_processing_time += chunk_processing_time
-             logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")
-
-
-             # Adjust entity positions relative to the original text
-             for result in raw_results_chunk:
-                 # Check if 'start' and 'end' exist before adjusting
-                 if 'start' in result and 'end' in result:
-                     result['start'] += start_pos
-                     result['end'] += start_pos
-                 else:
-                     logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")
-
-
-             all_raw_results.extend(raw_results_chunk)
-
-         logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
-         logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")
-
-         # Combine entities from all chunks, handling potential duplicates from overlap
-         # The combine_entities method needs refinement to handle overlaps better,
-         # e.g., by prioritizing entities from non-overlapped regions or merging based on confidence.
-         # For now, we use the existing combine_entities, which might create duplicates if
-         # an entity appears fully in the overlap region of two chunks.
-         # A more robust approach would involve deduplication based on start/end/type.
-         combined_entities = TextProcessor.combine_entities(all_raw_results, text)
-
-         # Simple deduplication based on exact start, end, and type
-         unique_entities = []
-         seen_entities = set()
-         for entity in combined_entities:
-             entity_key = (entity['start'], entity['end'], entity['entity_type'])
-             if entity_key not in seen_entities:
-                 unique_entities.append(entity)
-                 seen_entities.add(entity_key)
-             else:
-                 logger.debug(f"Duplicate entity removed: {entity}")
-
-         logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
-         return unique_entities
-
-
-     @staticmethod
-     def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
-         """
-         Split text into potentially overlapping chunks, trying to respect word boundaries.
-
-         Args:
-             text: The input text string.
-             chunk_size: The target maximum size of each chunk.
-             overlap: The desired overlap between consecutive chunks.
-
-         Returns:
-             A list of tuples, where each tuple contains (chunk_text, start_position_in_original_text).
-         """
-         if not text:
-             return []
-         if chunk_size <= overlap:
-             raise ValueError("chunk_size must be greater than overlap")
-         if chunk_size <= 0:
-             raise ValueError("chunk_size must be positive")
-
-
-         chunks = []
-         start = 0
-         text_len = len(text)
-
-         while start < text_len:
-             # Determine the ideal end position
-             end = start + chunk_size
-
-             # If the ideal end is beyond the text length, just take the rest
-             if end >= text_len:
-                 chunks.append((text[start:], start))
-                 break  # We've reached the end
-
-             # Try to find a suitable split point (e.g., whitespace) near the ideal end
-             # Search backwards from the ideal end position within a reasonable window (e.g., overlap size)
-             split_pos = -1
-             search_start = max(start, end - overlap)  # Don't search too far back
-             for i in range(end, search_start - 1, -1):
-                 # Prefer splitting at whitespace
-                 if text[i].isspace():
-                     split_pos = i + 1  # Split *after* the space
-                     break
-                 # Consider splitting at punctuation as a fallback? (optional)
-                 # import string
-                 # if text[i] in string.punctuation:
-                 #     split_pos = i + 1
-                 #     break
-
-
-             # If no good split point found nearby, just cut at the chunk_size
-             if split_pos == -1 or split_pos <= start:
-                 actual_end = end
-                 logger.debug(f"No suitable whitespace found near char {end}, cutting at {actual_end}")
-             else:
-                 actual_end = split_pos
-                 logger.debug(f"Found whitespace split point at char {actual_end}")
-
-
-             # Ensure the chunk isn't empty if split_pos was too close to start
-             if actual_end <= start:
-                 actual_end = end  # Fallback to hard cut if split logic fails
-
-             # Add the chunk and its starting position
-             chunks.append((text[start:actual_end], start))
-
-             # Determine the start of the next chunk
-             # Move forward by chunk_size minus overlap, ensuring progress
-             next_start = start + (chunk_size - overlap)
-
-             # If we split at whitespace (actual_end), we can potentially start the next chunk
-             # right after the split point to avoid redundant processing of the overlap zone
-             # if the split was significantly before the ideal 'end'.
-             # However, the simple `next_start = start + (chunk_size - overlap)` is safer
-             # to ensure consistent overlap handling unless more complex logic is added.
-             # Let's stick to the simpler approach for now:
-             # next_start = actual_end - overlap  # This could lead to variable overlap size
-
-             # Ensure we always make progress
-             if next_start <= start:
-                 logger.warning("Chunking logic resulted in no progress. Moving start by 1.")
-                 next_start = start + 1
-
-
-             start = next_start
-
-
-         return chunks
-
+ # app.py
+ import streamlit as st
+ from ner_module import NERModel, TextProcessor
+ import time
+ import logging
+
+ # Configure logging (optional, but helpful for debugging Streamlit apps)
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # --- Configuration ---
+ DEFAULT_MODEL = "Davlan/bert-base-multilingual-cased-ner-hrl"
+ # Alternative models (ensure they are compatible TokenClassification models)
+ # DEFAULT_MODEL = "dslim/bert-base-NER"  # English NER
+ # DEFAULT_MODEL = "xlm-roberta-large-finetuned-conll03-english"  # Another English option
+
+ DEFAULT_TEXT = """
+ Angela Merkel met Emmanuel Macron in Berlin on Tuesday to discuss the future of the European Union.
+ They visited the Brandenburg Gate and enjoyed some Currywurst. Later, they flew to Paris.
+ John Doe from New York works at Google LLC.
+ """
+ CHUNK_SIZE_DEFAULT = 500  # Slightly less than the common 512 limit to be safe
+ OVERLAP_DEFAULT = 50
+
+ # --- Caching ---
+ @st.cache_resource(show_spinner="Loading NER Model...")
+ def load_ner_model(model_name: str):
+     """
+     Loads the NERModel using the singleton pattern and caches the instance.
+     Streamlit's cache_resource is ideal for heavy objects like models.
+     """
+     try:
+         logger.info(f"Attempting to load model: {model_name}")
+         model_instance = NERModel.get_instance(model_name=model_name)
+         return model_instance
+     except Exception as e:
+         st.error(f"Failed to load model '{model_name}'. Error: {e}", icon="🚨")
+         logger.error(f"Fatal error loading model {model_name}: {e}")
+         return None
+
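
Worth noting about the pairing above: st.cache_resource memoizes per argument value and per process, while the singleton inside NERModel guards against a second model ever being constructed, even if the cached function runs again with a different name (in that case get_instance logs a warning and returns the first instance). A minimal sketch of that interaction, assuming ner_module is importable:

import streamlit as st
from ner_module import NERModel

@st.cache_resource
def load_cached(name: str):
    # Runs once per distinct `name`; the result is shared across reruns/sessions.
    return NERModel.get_instance(model_name=name)

m1 = load_cached("Davlan/bert-base-multilingual-cased-ner-hrl")
m2 = load_cached("dslim/bert-base-NER")  # new name, so the function body runs again,
assert m1 is m2                          # but the singleton still returns the first model
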
+ # --- Helper Functions ---
+ def get_color_for_entity(entity_type: str) -> str:
+     """Assigns a color based on the entity type for visualization."""
+     # Simple color mapping, can be expanded
+     colors = {
+         "PER": "#faa",   # Light red for Person
+         "ORG": "#afa",   # Light green for Organization
+         "LOC": "#aaf",   # Light blue for Location
+         "MISC": "#ffc",  # Light yellow for Miscellaneous
+         # Add more colors as needed based on the model's entity types
+     }
+     # Default color if type not found
+     return colors.get(entity_type.upper(), "#ddd")  # Light grey default
+
+ def highlight_entities(text: str, entities: list) -> str:
+     """
+     Generates an HTML string with entities highlighted using spans and colors.
+     Sorts entities by start position descending to handle nested entities correctly.
+     """
+     if not entities:
+         return text
+
+     # Sort entities by start index in descending order
+     # This ensures that inner entities are processed before outer ones if they overlap
+     entities.sort(key=lambda x: x['start'], reverse=True)
+
+     highlighted_text = text
+     for entity in entities:
+         start = entity['start']
+         end = entity['end']
+         entity_type = entity['entity_type']
+         word = entity['word']  # Use the extracted word for the title/tooltip
+         color = get_color_for_entity(entity_type)
+
+         # Create the highlighted span
+         highlight = (
+             f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
+             f'margin: 0 0.15em; line-height: 1; border-radius: 0.3em;" '
+             f'title="{entity_type}: {word} (Score: {entity.get("score", 0):.2f})">'  # Tooltip
+             f'{highlighted_text[start:end]}'  # Get the original text slice
+             f'<sup style="font-size: 0.7em; font-weight: bold; margin-left: 2px; color: #555;">{entity_type}</sup>'  # Small label
+             f'</span>'
+         )
+
+         # Replace the original text portion with the highlighted version
+         # Working backwards prevents index issues from altering string length
+         highlighted_text = highlighted_text[:start] + highlight + highlighted_text[end:]
+
+     return highlighted_text
+
+
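
The descending sort above matters because each splice changes the string's length: replacing left-to-right would shift every later entity's offsets. Right-to-left, each splice only touches text whose entities are already rendered. A toy trace of the same idea (the spans are hypothetical, not pipeline output):

text = "Angela Merkel met Emmanuel Macron"
spans = [(0, 13), (18, 33)]  # (start, end) pairs, ascending order
for start, end in sorted(spans, reverse=True):  # process right-to-left
    text = text[:start] + "<" + text[start:end] + ">" + text[end:]
print(text)  # <Angela Merkel> met <Emmanuel Macron>
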
+ # --- Streamlit App UI ---
+ st.set_page_config(layout="wide", page_title="NER Demo")
+
+ st.title("📝 Named Entity Recognition (NER) Demo")
+ st.markdown("Highlight Persons (PER), Organizations (ORG), Locations (LOC), and Miscellaneous (MISC) entities in text using a Hugging Face Transformer model.")
+
+ # Model selection fixed to default for simplicity
+ model_name = DEFAULT_MODEL
+
+ # Load the model (cached)
+ ner_model = load_ner_model(model_name)
+
+ if ner_model:  # Proceed only if the model loaded successfully
+     st.success(f"Model '{ner_model.model_name}' loaded successfully.", icon="✅")
+
+     # --- Input & Controls ---
+     col1, col2 = st.columns([3, 1])  # Input area takes more space
+
+     with col1:
+         st.subheader("Input Text")
+         # Use session state to keep text area content persistent across reruns
+         if 'text_input' not in st.session_state:
+             st.session_state.text_input = DEFAULT_TEXT
+         text_input = st.text_area("Enter text here:", value=st.session_state.text_input, height=250, key="text_area_input")
+         st.session_state.text_input = text_input  # Update session state on change
+
+     with col2:
+         st.subheader("Options")
+         use_chunking = st.checkbox("Process as Large Text (Chunking)", value=True)
+
+         chunk_size = CHUNK_SIZE_DEFAULT
+         overlap = OVERLAP_DEFAULT
+
+         if use_chunking:
+             chunk_size = st.slider("Chunk Size (chars)", min_value=100, max_value=1024, value=CHUNK_SIZE_DEFAULT, step=10)
+             overlap = st.slider("Overlap (chars)", min_value=10, max_value=chunk_size // 2, value=OVERLAP_DEFAULT, step=5)
+
+         process_button = st.button("Analyze Text", type="primary", use_container_width=True)
+
+     # --- Processing and Output ---
+     if process_button and text_input:
+         start_process_time = time.time()
+         st.markdown("---")  # Separator
+         st.subheader("Analysis Results")
+
+         with st.spinner("Analyzing text... Please wait."):
+             if use_chunking:
+                 logger.info(f"Processing with chunking: size={chunk_size}, overlap={overlap}")
+                 entities = TextProcessor.process_large_text(
+                     text=text_input,
+                     model=ner_model,
+                     chunk_size=chunk_size,
+                     overlap=overlap
+                 )
+             else:
+                 logger.info("Processing without chunking (potential truncation for long text)")
+                 entities = TextProcessor.process_large_text(
+                     text=text_input,
+                     model=ner_model,
+                     chunk_size=max(len(text_input), 512),  # Use text length or a large value
+                     overlap=0  # No overlap needed for single chunk
+                 )
+
+         end_process_time = time.time()
+         processing_duration = end_process_time - start_process_time
+         st.info(f"Analysis completed in {processing_duration:.2f} seconds. Found {len(entities)} entities.", icon="⏱️")
+
+         if entities:
+             # Display highlighted text
+             st.markdown("#### Highlighted Text:")
+             highlighted_html = highlight_entities(text_input, entities)
+             # Use st.markdown to render the HTML
+             st.markdown(highlighted_html, unsafe_allow_html=True)
+
+             # Display entities in a table-like format
+             st.markdown("#### Extracted Entities:")
+             # Sort entities by appearance order for the list
+             entities.sort(key=lambda x: x['start'])
+
+             # Use columns for a cleaner layout
+             cols = st.columns(3)  # Adjust number of columns as needed
+             col_idx = 0
+             for entity in entities:
+                 with cols[col_idx % len(cols)]:
+                     st.markdown(
+                         f"**{entity['entity_type']}** `{entity['score']:.2f}`: "
+                         f"{entity['word']} ({entity['start']}-{entity['end']})"
+                     )
+                 col_idx += 1
+
+             # Alternative display as an expander with detailed info
+             with st.expander("Show Detailed Entity List", expanded=False):
+                 for entity in entities:
+                     st.write(f"- **{entity['entity_type']}**: {entity['word']} (Score: {entity['score']:.2f}, Position: {entity['start']}-{entity['end']})")
+
+         else:
+             st.warning("No entities found in the provided text.", icon="❓")
+
+ elif process_button and not text_input:
+     st.warning("Please enter some text to analyze.", icon="⚠️")
+
+ else:
+     # This block runs if the model failed to load
+     st.error("NER model could not be loaded. Please check the logs or model name. The application cannot proceed.", icon="🛑")
+
+ # Add footer or instructions
+ st.markdown("---")
+ st.caption("Powered by Hugging Face Transformers and Streamlit.")
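
One caveat in the app above: the sliders measure chunk size in characters, while the model's hard limit (model_max_length, typically 512 for BERT-family checkpoints) is in tokens, and process_large_text compares the two numbers directly. A 500-character chunk can therefore still tokenize to more than 512 tokens in token-dense text. A sanity check one could add before prediction (a standard tokenizer call; the variable names are illustrative, not from the commit):

# `tokenizer` is the AutoTokenizer held by NERModel; `chunk_text` is one chunk.
n_tokens = len(tokenizer(chunk_text, add_special_tokens=True)["input_ids"])
if n_tokens > tokenizer.model_max_length:
    logger.warning(f"{len(chunk_text)}-char chunk is {n_tokens} tokens; reduce chunk_size.")
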
ner_module.py CHANGED
@@ -59,16 +59,13 @@ class NERModel:
             # Check if the model is a PyTorch model for potential optimizations
             if isinstance(self._model, torch.nn.Module):
                 self._model.eval()  # Set model to evaluation mode (important for inference)
-                 # self._model.share_memory()  # share_memory() might not be needed unless using multiprocessing explicitly

             # Create the NER pipeline
-             # Specify device=-1 for CPU, device=0 for first GPU, etc.
-             # Let pipeline decide device automatically by default, or specify if needed
             self._pipeline = pipeline(
                 "ner",
                 model=self._model,
                 tokenizer=self._tokenizer,
-                 # grouped_entities=True  # Group subword tokens automatically (alternative to manual combination)
+                 # grouped_entities=True  # Uncomment if you want to use pipeline's built-in grouping
             )

             load_time = time.time() - start_time
@@ -92,8 +89,7 @@ class NERModel:

         Returns:
             A list of dictionaries, where each dictionary represents an entity
-             identified by the pipeline. The exact format depends on the pipeline
-             configuration (e.g., grouped_entities).
+             identified by the pipeline.
         """
         if self._pipeline is None:
             logger.error("NER pipeline is not initialized. Cannot predict.")
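
The grouped_entities=True comment points at the pipeline's built-in alternative to the manual combine_entities pass below; in recent transformers releases the same switch is spelled aggregation_strategy. A sketch of the built-in grouping (results then carry 'entity_group' instead of per-token 'entity' tags):

from transformers import pipeline

ner = pipeline(
    "ner",
    model="Davlan/bert-base-multilingual-cased-ner-hrl",
    aggregation_strategy="simple",  # merge B-/I- subword pieces into whole entities
)
print(ner("Angela Merkel lives in Berlin."))
# e.g. [{'entity_group': 'PER', 'word': 'Angela Merkel', 'start': 0, 'end': 13, ...}, ...]
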
@@ -160,7 +156,6 @@ class TextProcessor:
                 entity_type = entity_tag[2:]
             else:
                 # Handle cases where the tag might not have B-/I- prefix (less common)
-                 logger.warning(f"Unexpected entity tag format: {entity_tag}. Using as is.")
                 entity_type = entity_tag

             # Start of a new entity ('B-') or continuation of a different entity type
@@ -188,11 +183,17 @@

             # Handle unexpected cases (e.g., I- tag without preceding B- or matching I-)
             else:
-                 logger.warning(f"Encountered unexpected token sequence at: {token}. Resetting current entity.")
+                 logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.")
                 if current_entity:
                     combined_entities.append(current_entity)
-                 current_entity = None  # Reset
-
+                 # Try to create a new entity from this token
+                 current_entity = {
+                     'entity_type': entity_type,
+                     'start': token['start'],
+                     'end': token['end'],
+                     'score': float(token['score']),
+                     'token_count': 1
+                 }

         # Add the last tracked entity if it exists
         if current_entity:
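
The behavioral change in the hunk above concerns the fallback else branch, which catches tags carrying no B-/I- prefix at all (some models emit plain 'PER'-style tags). Previously such a token closed the current entity and was itself dropped; now it opens a new entity. Illustrated with hypothetical raw tokens (not actual pipeline output):

tokens = [
    {'entity': 'PER', 'start': 0, 'end': 6, 'score': 0.98},   # "Angela": no B-/I- prefix
    {'entity': 'PER', 'start': 7, 'end': 13, 'score': 0.97},  # "Merkel"
]
# Before this commit, each such token was logged and current_entity reset, so
# both words were lost. After it, the first token opens a PER entity; the
# second closes that span and opens another, so both words survive.
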
@@ -201,16 +202,18 @@
         # Extract the actual text 'word' for each combined entity
         for entity in combined_entities:
             try:
-                 entity['word'] = original_text[entity['start']:entity['end']].strip()
+                 # Ensure indices are valid
+                 start = max(0, min(entity['start'], len(original_text)))
+                 end = max(start, min(entity['end'], len(original_text)))
+                 entity['word'] = original_text[start:end].strip()
                 # Remove internal helper key
                 if 'token_count' in entity:
                     del entity['token_count']
-             except IndexError:
-                 logger.error(f"Index error extracting word for entity: {entity} with text length {len(original_text)}")
+             except Exception as e:
+                 logger.error(f"Error extracting word for entity: {entity}, error: {e}")
                 entity['word'] = "[Error extracting word]"

-
-         # Optional: Sort entities by start position
+         # Sort entities by start position
         combined_entities.sort(key=lambda x: x['start'])

         logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
@@ -225,10 +228,8 @@
         Args:
             text: The large input text string.
             model: The initialized NERModel instance.
-             chunk_size: The maximum size of each text chunk (in characters or tokens,
-                         depending on the tokenizer's limits, often ~512 for BERT).
-             overlap: The number of characters/tokens to overlap between consecutive chunks
-                      to ensure entities spanning chunk boundaries are captured.
+             chunk_size: The maximum size of each text chunk.
+             overlap: The number of characters to overlap between consecutive chunks.

         Returns:
             A list of combined entity dictionaries for the entire text.
@@ -247,7 +248,6 @@
             logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
             overlap = chunk_size // 4

-
         logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
         chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
         logger.info(f"Split text into {len(chunks)} chunks.")
@@ -266,7 +266,6 @@
             total_processing_time += chunk_processing_time
             logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")

-
             # Adjust entity positions relative to the original text
             for result in raw_results_chunk:
                 # Check if 'start' and 'end' exist before adjusting
@@ -276,35 +275,48 @@
                 else:
                     logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")

-
             all_raw_results.extend(raw_results_chunk)

         logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
         logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")

-         # Combine entities from all chunks, handling potential duplicates from overlap
-         # The combine_entities method needs refinement to handle overlaps better,
-         # e.g., by prioritizing entities from non-overlapped regions or merging based on confidence.
-         # For now, we use the existing combine_entities, which might create duplicates if
-         # an entity appears fully in the overlap region of two chunks.
-         # A more robust approach would involve deduplication based on start/end/type.
+         # Combine entities from all chunks
         combined_entities = TextProcessor.combine_entities(all_raw_results, text)

-         # Simple deduplication based on exact start, end, and type
+         # Deduplicate entities based on overlapping positions
+         # Two entities are considered duplicates if they have the same type and
+         # overlap by more than 50% of the shorter entity's length
         unique_entities = []
-         seen_entities = set()
         for entity in combined_entities:
-             entity_key = (entity['start'], entity['end'], entity['entity_type'])
-             if entity_key not in seen_entities:
+             is_duplicate = False
+             # Calculate entity length for overlap comparison
+             entity_length = entity['end'] - entity['start']
+
+             for existing in unique_entities:
+                 if existing['entity_type'] == entity['entity_type']:
+                     # Check for significant overlap
+                     overlap_start = max(entity['start'], existing['start'])
+                     overlap_end = min(entity['end'], existing['end'])
+                     if overlap_start < overlap_end:  # They overlap
+                         overlap_length = overlap_end - overlap_start
+                         shorter_length = min(entity_length, existing['end'] - existing['start'])
+
+                         # If overlap is significant (>50% of shorter entity)
+                         if overlap_length > 0.5 * shorter_length:
+                             is_duplicate = True
+                             # Keep the one with higher score
+                             if entity['score'] > existing['score']:
+                                 # Replace the existing entity with this one
+                                 unique_entities.remove(existing)
+                                 is_duplicate = False
+                             break
+
+             if not is_duplicate:
                 unique_entities.append(entity)
-                 seen_entities.add(entity_key)
-             else:
-                 logger.debug(f"Duplicate entity removed: {entity}")

         logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
         return unique_entities

-
     @staticmethod
     def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
         """
@@ -325,71 +337,41 @@
         if chunk_size <= 0:
             raise ValueError("chunk_size must be positive")

-
         chunks = []
         start = 0
         text_len = len(text)

         while start < text_len:
             # Determine the ideal end position
-             end = start + chunk_size
-
-             # If the ideal end is beyond the text length, just take the rest
+             end = min(start + chunk_size, text_len)
+
+             # If we're at the end of the text, just use what's left
             if end >= text_len:
                 chunks.append((text[start:], start))
-                 break  # We've reached the end
+                 break

-             # Try to find a suitable split point (e.g., whitespace) near the ideal end
-             # Search backwards from the ideal end position within a reasonable window (e.g., overlap size)
+             # Try to find a suitable split point (whitespace) to ensure we don't cut words
             split_pos = -1
-             search_start = max(start, end - overlap)  # Don't search too far back
-             for i in range(end, search_start - 1, -1):
-                 # Prefer splitting at whitespace
-                 if text[i].isspace():
-                     split_pos = i + 1  # Split *after* the space
-                     break
-                 # Consider splitting at punctuation as a fallback? (optional)
-                 # import string
-                 # if text[i] in string.punctuation:
-                 #     split_pos = i + 1
-                 #     break
-
-
-             # If no good split point found nearby, just cut at the chunk_size
+             # Search backwards from end to find a whitespace
+             for i in range(end, max(start, end - overlap) - 1, -1):
+                 if i < text_len and text[i].isspace():
+                     split_pos = i + 1  # Position after the space
+                     break
+
+             # If no good split found, just use the calculated end
             if split_pos == -1 or split_pos <= start:
-                 actual_end = end
-                 logger.debug(f"No suitable whitespace found near char {end}, cutting at {actual_end}")
+                 actual_end = end
             else:
-                 actual_end = split_pos
-                 logger.debug(f"Found whitespace split point at char {actual_end}")
-
-
-             # Ensure the chunk isn't empty if split_pos was too close to start
-             if actual_end <= start:
-                 actual_end = end  # Fallback to hard cut if split logic fails
-
-             # Add the chunk and its starting position
+                 actual_end = split_pos
+
+             # Add the chunk
             chunks.append((text[start:actual_end], start))
-
-             # Determine the start of the next chunk
-             # Move forward by chunk_size minus overlap, ensuring progress
-             next_start = start + (chunk_size - overlap)
-
-             # If we split at whitespace (actual_end), we can potentially start the next chunk
-             # right after the split point to avoid redundant processing of the overlap zone
-             # if the split was significantly before the ideal 'end'.
-             # However, the simple `next_start = start + (chunk_size - overlap)` is safer
-             # to ensure consistent overlap handling unless more complex logic is added.
-             # Let's stick to the simpler approach for now:
-             # next_start = actual_end - overlap  # This could lead to variable overlap size
-
-             # Ensure we always make progress
+
+             # Calculate next start position, ensuring we make progress
+             next_start = start + (actual_end - start - overlap)
             if next_start <= start:
-                 logger.warning("Chunking logic resulted in no progress. Moving start by 1.")
                 next_start = start + 1
-
-
+
             start = next_start

-
         return chunks
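
With end clamped to text_len and next_start computed as actual_end - overlap, consecutive chunks now overlap relative to where a chunk actually ended rather than its ideal end. A quick trace of the new loop on a toy string (a standalone re-implementation for illustration, not the module itself):

def create_chunks(text: str, chunk_size: int, overlap: int):
    chunks, start, text_len = [], 0, len(text)
    while start < text_len:
        end = min(start + chunk_size, text_len)
        if end >= text_len:
            chunks.append((text[start:], start))
            break
        split_pos = -1
        for i in range(end, max(start, end - overlap) - 1, -1):
            if i < text_len and text[i].isspace():
                split_pos = i + 1
                break
        actual_end = end if split_pos == -1 or split_pos <= start else split_pos
        chunks.append((text[start:actual_end], start))
        next_start = start + (actual_end - start - overlap)
        start = next_start if next_start > start else start + 1
    return chunks

print(create_chunks("one two three four five six seven", chunk_size=12, overlap=4))
# -> [('one two thre', 0), ('three four ', 8), ('our five six ', 15), ('six seven', 24)]
# Each chunk starts 4 characters before the previous one ended, and splits fall
# after whitespace where one exists within the search window.
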
 