Abhigyan committed
Commit f68c4f8 · Parent(s): 3bdd5ce

Refactor
Browse files:
- __pycache__/ner_module.cpython-310.pyc +0 -0
- app.py +184 -381
- ner_module.py +68 -86
__pycache__/ner_module.cpython-310.pyc
ADDED
Binary file (9.93 kB).
app.py
CHANGED
@@ -1,396 +1,199 @@
-# [...]
-import torch
 import time
-from typing import List, Dict, Any, Tuple
-from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 import logging
 
-# Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 """
-[... lines 14-15 not recoverable from the source ...]
 """
-[... lines 17-40 not recoverable from the source ...]
-        if NERModel._instance is not None:
-            raise Exception("This class is a singleton! Use get_instance() to get the object.")
-        else:
-            self.model_name = model_name
-            NERModel._model_name = model_name  # Store the model name
-            self._load_model()
-            NERModel._instance = self  # Assign the instance here
-
-    def _load_model(self):
-        """Load the NER model and tokenizer from Hugging Face."""
-        logger.info(f"Loading model: {self.model_name}")
-        start_time = time.time()
-
-        try:
-            # Load tokenizer and model
-            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)
-
-            # Check if the model is a PyTorch model for potential optimizations
-            if isinstance(self._model, torch.nn.Module):
-                self._model.eval()  # Set model to evaluation mode (important for inference)
-                # self._model.share_memory() # share_memory() might not be needed unless using multiprocessing explicitly
-
-            # Create the NER pipeline
-            # Specify device=-1 for CPU, device=0 for first GPU, etc.
-            # Let pipeline decide device automatically by default, or specify if needed
-            self._pipeline = pipeline(
-                "ner",
-                model=self._model,
-                tokenizer=self._tokenizer,
-                # grouped_entities=True # Group subword tokens automatically (alternative to manual combination)
-            )
-
-            load_time = time.time() - start_time
-            logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")
-
-        except Exception as e:
-            logger.error(f"Error loading model {self.model_name}: {e}")
-            # Clean up partial loads if necessary
-            self._tokenizer = None
-            self._model = None
-            self._pipeline = None
-            # Re-raise the exception to signal failure
-            raise
-
-    def predict(self, text: str) -> List[Dict[str, Any]]:
-        """
-        Run NER prediction on the input text using the loaded pipeline.
-
-        Args:
-            text: The input string to perform NER on.
-
-        Returns:
-            A list of dictionaries, where each dictionary represents an entity
-            identified by the pipeline. The exact format depends on the pipeline
-            configuration (e.g., grouped_entities).
-        """
-        if self._pipeline is None:
-            logger.error("NER pipeline is not initialized. Cannot predict.")
-            return []  # Return empty list or raise an error
-
-        if not text or not isinstance(text, str):
-            logger.warning("Prediction called with empty or invalid text.")
-            return []
-
-        logger.debug(f"Running prediction on text: '{text[:100]}...'")  # Log snippet
-        try:
-            # The pipeline handles tokenization and prediction
-            results = self._pipeline(text)
-            logger.debug(f"Prediction results: {results}")
-            return results
-        except Exception as e:
-            logger.error(f"Error during NER prediction: {e}")
-            return []  # Return empty list on error
-
-
-class TextProcessor:
 """
-[... lines 119-120 not recoverable from the source ...]
 """
-[... lines 122-207 not recoverable from the source ...]
-            except IndexError:
-                logger.error(f"Index error extracting word for entity: {entity} with text length {len(original_text)}")
-                entity['word'] = "[Error extracting word]"
-
-        # Optional: Sort entities by start position
-        combined_entities.sort(key=lambda x: x['start'])
-
-        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
-        return combined_entities
-
-    @staticmethod
-    def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
-        """
-        Process large text by splitting it into overlapping chunks, running NER
-        on each chunk, and then combining the results intelligently.
-
-        Args:
-            text: The large input text string.
-            model: The initialized NERModel instance.
-            chunk_size: The maximum size of each text chunk (in characters or tokens,
-                        depending on the tokenizer's limits, often ~512 for BERT).
-            overlap: The number of characters/tokens to overlap between consecutive chunks
-                     to ensure entities spanning chunk boundaries are captured.
-
-        Returns:
-            A list of combined entity dictionaries for the entire text.
-        """
-        if not text:
-            return []
-
-        # Use tokenizer max length if available and smaller than chunk_size
-        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
-            tokenizer_max_len = model._tokenizer.model_max_length
-            if chunk_size > tokenizer_max_len:
-                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
-                chunk_size = tokenizer_max_len
-        # Ensure overlap is reasonable compared to chunk size
-        if overlap >= chunk_size // 2:
-            logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
-            overlap = chunk_size // 4
-
-        logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
-        chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
-        logger.info(f"Split text into {len(chunks)} chunks.")
-
-        all_raw_results = []
-        total_processing_time = 0
-
-        for i, (chunk_text, start_pos) in enumerate(chunks):
-            logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
-            start_time = time.time()
-
-            # Get raw predictions for the current chunk
-            raw_results_chunk = model.predict(chunk_text)
-
-            chunk_processing_time = time.time() - start_time
-            total_processing_time += chunk_processing_time
-            logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")
-
-            # Adjust entity positions relative to the original text
-            for result in raw_results_chunk:
-                # Check if 'start' and 'end' exist before adjusting
-                if 'start' in result and 'end' in result:
-                    result['start'] += start_pos
-                    result['end'] += start_pos
-                else:
-                    logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")
-
-            all_raw_results.extend(raw_results_chunk)
-
-        logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
-        logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")
-
-        # Combine entities from all chunks, handling potential duplicates from overlap
-        # The combine_entities method needs refinement to handle overlaps better,
-        # e.g., by prioritizing entities from non-overlapped regions or merging based on confidence.
-        # For now, we use the existing combine_entities, which might create duplicates if
-        # an entity appears fully in the overlap region of two chunks.
-        # A more robust approach would involve deduplication based on start/end/type.
-        combined_entities = TextProcessor.combine_entities(all_raw_results, text)
-
-        # Simple deduplication based on exact start, end, and type
-        unique_entities = []
-        seen_entities = set()
-        for entity in combined_entities:
-            entity_key = (entity['start'], entity['end'], entity['entity_type'])
-            if entity_key not in seen_entities:
-                unique_entities.append(entity)
-                seen_entities.add(entity_key)
-            else:
-                logger.debug(f"Duplicate entity removed: {entity}")
-[... lines 303-341 not recoverable from the source ...]
-            # Try to find a suitable split point (e.g., whitespace) near the ideal end
-            # Search backwards from the ideal end position within a reasonable window (e.g., overlap size)
-            split_pos = -1
-            search_start = max(start, end - overlap)  # Don't search too far back
-            for i in range(end, search_start - 1, -1):
-                # Prefer splitting at whitespace
-                if text[i].isspace():
-                    split_pos = i + 1  # Split *after* the space
-                    break
-                # Consider splitting at punctuation as a fallback? (optional)
-                # import string
-                # if text[i] in string.punctuation:
-                #     split_pos = i + 1
-                #     break
-
-            # If no good split point found nearby, just cut at the chunk_size
-            if split_pos == -1 or split_pos <= start:
-                actual_end = end
-                logger.debug(f"No suitable whitespace found near char {end}, cutting at {actual_end}")
-            else:
-                actual_end = split_pos
-                logger.debug(f"Found whitespace split point at char {actual_end}")
-
-            # Ensure the chunk isn't empty if split_pos was too close to start
-            if actual_end <= start:
-                actual_end = end  # Fallback to hard cut if split logic fails
 
-            chunks.append((text[start:actual_end], start))
-
-            # Determine the start of the next chunk
-            # Move forward by chunk_size minus overlap, ensuring progress
-            next_start = start + (chunk_size - overlap)
-
-            # If we split at whitespace (actual_end), we can potentially start the next chunk
-            # right after the split point to avoid redundant processing of the overlap zone
-            # if the split was significantly before the ideal 'end'.
-            # However, the simple `next_start = start + (chunk_size - overlap)` is safer
-            # to ensure consistent overlap handling unless more complex logic is added.
-            # Let's stick to the simpler approach for now:
-            # next_start = actual_end - overlap # This could lead to variable overlap size
-
-            # Ensure we always make progress
-            if next_start <= start:
-                logger.warning("Chunking logic resulted in no progress. Moving start by 1.")
-                next_start = start + 1
-
-            start = next_start
-
-        return chunks
+# app.py
+import streamlit as st
+from ner_module import NERModel, TextProcessor
 import time
 import logging
 
+# Configure logging (optional, but helpful for debugging Streamlit apps)
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+# --- Configuration ---
+DEFAULT_MODEL = "Davlan/bert-base-multilingual-cased-ner-hrl"
+# Alternative models (ensure they are compatible TokenClassification models)
+# DEFAULT_MODEL = "dslim/bert-base-NER" # English NER
+# DEFAULT_MODEL = "xlm-roberta-large-finetuned-conll03-english" # Another English option
+
+DEFAULT_TEXT = """
+Angela Merkel met Emmanuel Macron in Berlin on Tuesday to discuss the future of the European Union.
+They visited the Brandenburg Gate and enjoyed some Currywurst. Later, they flew to Paris.
+John Doe from New York works at Google LLC.
+"""
+CHUNK_SIZE_DEFAULT = 500  # Slightly less than common 512 limit to be safe
+OVERLAP_DEFAULT = 50
+
+# --- Caching ---
+@st.cache_resource(show_spinner="Loading NER Model...")
+def load_ner_model(model_name: str):
     """
+    Loads the NERModel using the singleton pattern and caches the instance.
+    Streamlit's cache_resource is ideal for heavy objects like models.
     """
+    try:
+        logger.info(f"Attempting to load model: {model_name}")
+        model_instance = NERModel.get_instance(model_name=model_name)
+        return model_instance
+    except Exception as e:
+        st.error(f"Failed to load model '{model_name}'. Error: {e}", icon="🚨")
+        logger.error(f"Fatal error loading model {model_name}: {e}")
+        return None
+
+# --- Helper Functions ---
+def get_color_for_entity(entity_type: str) -> str:
+    """Assigns a color based on the entity type for visualization."""
+    # Simple color mapping, can be expanded
+    colors = {
+        "PER": "#faa",   # Light red for Person
+        "ORG": "#afa",   # Light green for Organization
+        "LOC": "#aaf",   # Light blue for Location
+        "MISC": "#ffc",  # Light yellow for Miscellaneous
+        # Add more colors as needed based on model's entity types
+    }
+    # Default color if type not found
+    return colors.get(entity_type.upper(), "#ddd")  # Light grey default
+
+def highlight_entities(text: str, entities: list) -> str:
     """
+    Generates an HTML string with entities highlighted using spans and colors.
+    Sorts entities by start position descending to handle nested entities correctly.
     """
+    if not entities:
+        return text
+
+    # Sort entities by start index in descending order
+    # This ensures that inner entities are processed before outer ones if they overlap
+    entities.sort(key=lambda x: x['start'], reverse=True)
+
+    highlighted_text = text
+    for entity in entities:
+        start = entity['start']
+        end = entity['end']
+        entity_type = entity['entity_type']
+        word = entity['word']  # Use the extracted word for the title/tooltip
+        color = get_color_for_entity(entity_type)
+
+        # Create the highlighted span
+        highlight = (
+            f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
+            f'margin: 0 0.15em; line-height: 1; border-radius: 0.3em;" '
+            f'title="{entity_type}: {word} (Score: {entity.get("score", 0):.2f})">'  # Tooltip
+            f'{highlighted_text[start:end]}'  # Get the original text slice
+            f'<sup style="font-size: 0.7em; font-weight: bold; margin-left: 2px; color: #555;">{entity_type}</sup>'  # Small label
+            f'</span>'
+        )
+
+        # Replace the original text portion with the highlighted version
+        # Working backwards prevents index issues from altering string length
+        highlighted_text = highlighted_text[:start] + highlight + highlighted_text[end:]
+
+    return highlighted_text
+
+
+# --- Streamlit App UI ---
+st.set_page_config(layout="wide", page_title="NER Demo")
+
+st.title("📝 Named Entity Recognition (NER) Demo")
+st.markdown("Highlight Persons (PER), Organizations (ORG), Locations (LOC), and Miscellaneous (MISC) entities in text using a Hugging Face Transformer model.")
+
+# Model selection fixed to default for simplicity
+model_name = DEFAULT_MODEL
+
+# Load the model (cached)
+ner_model = load_ner_model(model_name)
+
+if ner_model:  # Proceed only if the model loaded successfully
+    st.success(f"Model '{ner_model.model_name}' loaded successfully.", icon="✅")
+
+    # --- Input & Controls ---
+    col1, col2 = st.columns([3, 1])  # Input area takes more space
+
+    with col1:
+        st.subheader("Input Text")
+        # Use session state to keep text area content persistent across reruns
+        if 'text_input' not in st.session_state:
+            st.session_state.text_input = DEFAULT_TEXT
+        text_input = st.text_area("Enter text here:", value=st.session_state.text_input, height=250, key="text_area_input")
+        st.session_state.text_input = text_input  # Update session state on change
+
+    with col2:
+        st.subheader("Options")
+        use_chunking = st.checkbox("Process as Large Text (Chunking)", value=True)
+
+        chunk_size = CHUNK_SIZE_DEFAULT
+        overlap = OVERLAP_DEFAULT
+
+        if use_chunking:
+            chunk_size = st.slider("Chunk Size (chars)", min_value=100, max_value=1024, value=CHUNK_SIZE_DEFAULT, step=10)
+            overlap = st.slider("Overlap (chars)", min_value=10, max_value=chunk_size // 2, value=OVERLAP_DEFAULT, step=5)
+
+        process_button = st.button("✨ Analyze Text", type="primary", use_container_width=True)
+
+    # --- Processing and Output ---
+    if process_button and text_input:
+        start_process_time = time.time()
+        st.markdown("---")  # Separator
+        st.subheader("Analysis Results")
+
+        with st.spinner("Analyzing text... Please wait."):
+            if use_chunking:
+                logger.info(f"Processing with chunking: size={chunk_size}, overlap={overlap}")
+                entities = TextProcessor.process_large_text(
+                    text=text_input,
+                    model=ner_model,
+                    chunk_size=chunk_size,
+                    overlap=overlap
+                )
             else:
+                logger.info("Processing without chunking (potential truncation for long text)")
+                entities = TextProcessor.process_large_text(
+                    text=text_input,
+                    model=ner_model,
+                    chunk_size=max(len(text_input), 512),  # Use text length or a large value
+                    overlap=0  # No overlap needed for single chunk
+                )
+
+        end_process_time = time.time()
+        processing_duration = end_process_time - start_process_time
+        st.info(f"Analysis completed in {processing_duration:.2f} seconds. Found {len(entities)} entities.", icon="⏱️")
+
+        if entities:
+            # Display highlighted text
+            st.markdown("#### Highlighted Text:")
+            highlighted_html = highlight_entities(text_input, entities)
+            # Use st.markdown to render the HTML
+            st.markdown(highlighted_html, unsafe_allow_html=True)
+
+            # Display entities in a table-like format
+            st.markdown("#### Extracted Entities:")
+            # Sort entities by appearance order for the list
+            entities.sort(key=lambda x: x['start'])
+
+            # Use columns for a cleaner layout
+            cols = st.columns(3)  # Adjust number of columns as needed
+            col_idx = 0
+            for entity in entities:
+                with cols[col_idx % len(cols)]:
+                    st.markdown(
+                        f"**{entity['entity_type']}** `{entity['score']:.2f}`: "
+                        f"{entity['word']} ({entity['start']}-{entity['end']})"
+                    )
+                col_idx += 1
+
+            # Alternative display as an expander with detailed info
+            with st.expander("Show Detailed Entity List", expanded=False):
+                for entity in entities:
+                    st.write(f"- **{entity['entity_type']}**: {entity['word']} (Score: {entity['score']:.2f}, Position: {entity['start']}-{entity['end']})")
 
+        else:
+            st.warning("No entities found in the provided text.", icon="❓")
 
+    elif process_button and not text_input:
+        st.warning("Please enter some text to analyze.", icon="⚠️")
 
+else:
+    # This block runs if the model failed to load
+    st.error("NER model could not be loaded. Please check the logs or model name. The application cannot proceed.", icon="🛑")
 
+# Add footer or instructions
+st.markdown("---")
+st.caption("Powered by Hugging Face Transformers and Streamlit.")
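Note: `load_ner_model` above calls `NERModel.get_instance`, which lives in an unchanged region of ner_module.py and therefore never appears in this diff. Going only by the singleton guard visible in the removed code, a minimal sketch of such an accessor (an assumption, not the committed implementation) could look like this:

# Hypothetical sketch of the singleton accessor called by load_ner_model().
# The real get_instance() is in the unchanged part of ner_module.py.
class NERModel:
    _instance = None
    _model_name = None

    def __init__(self, model_name: str):
        if NERModel._instance is not None:
            raise Exception("This class is a singleton! Use get_instance() to get the object.")
        self.model_name = model_name
        NERModel._model_name = model_name
        # In the real class, self._load_model() builds the tokenizer/pipeline here.
        NERModel._instance = self

    @classmethod
    def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
        # Create the singleton on first use; return the cached instance afterwards.
        if cls._instance is None:
            cls(model_name)
        return cls._instance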
ner_module.py
CHANGED
@@ -59,16 +59,13 @@ class NERModel:
             # Check if the model is a PyTorch model for potential optimizations
             if isinstance(self._model, torch.nn.Module):
                 self._model.eval()  # Set model to evaluation mode (important for inference)
-                # self._model.share_memory() # share_memory() might not be needed unless using multiprocessing explicitly
 
             # Create the NER pipeline
-            # Specify device=-1 for CPU, device=0 for first GPU, etc.
-            # Let pipeline decide device automatically by default, or specify if needed
             self._pipeline = pipeline(
                 "ner",
                 model=self._model,
                 tokenizer=self._tokenizer,
-                # grouped_entities=True # Group subword tokens automatically (alternative to manual combination)
+                # grouped_entities=True # Uncomment if you want to use pipeline's built-in grouping
             )
 
             load_time = time.time() - start_time
@@ -92,8 +89,7 @@ class NERModel:
 
         Returns:
             A list of dictionaries, where each dictionary represents an entity
-            identified by the pipeline. The exact format depends on the pipeline
-            configuration (e.g., grouped_entities).
+            identified by the pipeline.
         """
         if self._pipeline is None:
             logger.error("NER pipeline is not initialized. Cannot predict.")
@@ -160,7 +156,6 @@ class TextProcessor:
                 entity_type = entity_tag[2:]
             else:
                 # Handle cases where the tag might not have B-/I- prefix (less common)
-                logger.warning(f"Unexpected entity tag format: {entity_tag}. Using as is.")
                 entity_type = entity_tag
 
             # Start of a new entity ('B-') or continuation of a different entity type
@@ -188,11 +183,17 @@ class TextProcessor:
 
             # Handle unexpected cases (e.g., I- tag without preceding B- or matching I-)
             else:
-                logger.warning(f"Encountered unexpected token sequence at: {token}. [...]")
+                logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.")
                 if current_entity:
                     combined_entities.append(current_entity)
-                [... two removed lines not recoverable from the source ...]
+                # Try to create a new entity from this token
+                current_entity = {
+                    'entity_type': entity_type,
+                    'start': token['start'],
+                    'end': token['end'],
+                    'score': float(token['score']),
+                    'token_count': 1
+                }
 
         # Add the last tracked entity if it exists
         if current_entity:
@@ -201,16 +202,18 @@ class TextProcessor:
         # Extract the actual text 'word' for each combined entity
         for entity in combined_entities:
             try:
-                [... removed line not recoverable from the source ...]
+                # Ensure indices are valid
+                start = max(0, min(entity['start'], len(original_text)))
+                end = max(start, min(entity['end'], len(original_text)))
+                entity['word'] = original_text[start:end].strip()
                 # Remove internal helper key
                 if 'token_count' in entity:
                     del entity['token_count']
-            except IndexError:
-                logger.error(f"Index error extracting word for entity: {entity} with text length {len(original_text)}")
+            except Exception as e:
+                logger.error(f"Error extracting word for entity: {entity}, error: {e}")
                 entity['word'] = "[Error extracting word]"
 
-        # Optional: Sort entities by start position
+        # Sort entities by start position
        combined_entities.sort(key=lambda x: x['start'])
 
        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
@@ -225,10 +228,8 @@ class TextProcessor:
         Args:
             text: The large input text string.
             model: The initialized NERModel instance.
-            chunk_size: The maximum size of each text chunk (in characters or tokens,
-                        depending on the tokenizer's limits, often ~512 for BERT).
-            overlap: The number of characters/tokens to overlap between consecutive chunks
-                     to ensure entities spanning chunk boundaries are captured.
+            chunk_size: The maximum size of each text chunk.
+            overlap: The number of characters to overlap between consecutive chunks.
 
         Returns:
             A list of combined entity dictionaries for the entire text.
@@ -247,7 +248,6 @@ class TextProcessor:
             logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
             overlap = chunk_size // 4
 
-
         logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
         chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
         logger.info(f"Split text into {len(chunks)} chunks.")
@@ -266,7 +266,6 @@ class TextProcessor:
             total_processing_time += chunk_processing_time
             logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")
 
-
             # Adjust entity positions relative to the original text
             for result in raw_results_chunk:
                 # Check if 'start' and 'end' exist before adjusting
@@ -276,35 +275,48 @@ class TextProcessor:
                 else:
                     logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")
 
-
             all_raw_results.extend(raw_results_chunk)
 
         logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
         logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")
 
-        # Combine entities from all chunks, handling potential duplicates from overlap
-        # The combine_entities method needs refinement to handle overlaps better,
-        # e.g., by prioritizing entities from non-overlapped regions or merging based on confidence.
-        # For now, we use the existing combine_entities, which might create duplicates if
-        # an entity appears fully in the overlap region of two chunks.
-        # A more robust approach would involve deduplication based on start/end/type.
+        # Combine entities from all chunks
         combined_entities = TextProcessor.combine_entities(all_raw_results, text)
 
-        # Simple deduplication based on exact start, end, and type
+        # Deduplicate entities based on overlapping positions
+        # Two entities are considered duplicates if they have the same type and
+        # overlap by more than 50% of the shorter entity's length
         unique_entities = []
-        seen_entities = set()
         for entity in combined_entities:
-            entity_key = (entity['start'], entity['end'], entity['entity_type'])
-            if entity_key not in seen_entities:
+            is_duplicate = False
+            # Calculate entity length for overlap comparison
+            entity_length = entity['end'] - entity['start']
+
+            for existing in unique_entities:
+                if existing['entity_type'] == entity['entity_type']:
+                    # Check for significant overlap
+                    overlap_start = max(entity['start'], existing['start'])
+                    overlap_end = min(entity['end'], existing['end'])
+                    if overlap_start < overlap_end:  # They overlap
+                        overlap_length = overlap_end - overlap_start
+                        shorter_length = min(entity_length, existing['end'] - existing['start'])
+
+                        # If overlap is significant (>50% of shorter entity)
+                        if overlap_length > 0.5 * shorter_length:
+                            is_duplicate = True
+                            # Keep the one with higher score
+                            if entity['score'] > existing['score']:
+                                # Replace the existing entity with this one
+                                unique_entities.remove(existing)
+                                is_duplicate = False
+                            break
+
+            if not is_duplicate:
                 unique_entities.append(entity)
-                seen_entities.add(entity_key)
-            else:
-                logger.debug(f"Duplicate entity removed: {entity}")
 
         logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
         return unique_entities
 
-
     @staticmethod
     def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
         """
@@ -325,71 +337,41 @@ class TextProcessor:
         if chunk_size <= 0:
             raise ValueError("chunk_size must be positive")
 
-
         chunks = []
         start = 0
         text_len = len(text)
 
         while start < text_len:
             # Determine the ideal end position
-            end = start + chunk_size
-
-            # If the [...]
+            end = min(start + chunk_size, text_len)
+
+            # If we're at the end of the text, just use what's left
             if end >= text_len:
                 chunks.append((text[start:], start))
                 break
 
-            # Try to find a suitable split point (e.g., whitespace) near the ideal end
-            # Search backwards from the ideal end position within a reasonable window (e.g., overlap size)
             split_pos = -1
-            search_start = max(start, end - overlap)  # Don't search too far back
-            for i in range(end, search_start - 1, -1):
-                # Prefer splitting at whitespace
-                if text[i].isspace():
-                    split_pos = i + 1  # Split *after* the space
-                    break
-                # Consider splitting at punctuation as a fallback? (optional)
-                # import string
-                # if text[i] in string.punctuation:
-                #     split_pos = i + 1
-                #     break
-
-            # If no good split point found nearby, just cut at the chunk_size
+            # Try to find a suitable split point (whitespace) to ensure we don't cut words
+            # Search backwards from end to find a whitespace
+            for i in range(end, max(start, end - overlap) - 1, -1):
+                if i < text_len and text[i].isspace():
+                    split_pos = i + 1  # Position after the space
+                    break
+
+            # If no good split found, just use the calculated end
             if split_pos == -1 or split_pos <= start:
                 actual_end = end
-                logger.debug(f"No suitable whitespace found near char {end}, cutting at {actual_end}")
             else:
                 actual_end = split_pos
-                logger.debug(f"Found whitespace split point at char {actual_end}")
-
-            # Ensure the chunk isn't empty if split_pos was too close to start
-            if actual_end <= start:
-                actual_end = end  # Fallback to hard cut if split logic fails
 
-            # Add the chunk and its starting position
+            # Add the chunk
             chunks.append((text[start:actual_end], start))
 
-            # Determine the start of the next chunk
-            # Move forward by chunk_size minus overlap, ensuring progress
-            next_start = start + (chunk_size - overlap)
-
-            # If we split at whitespace (actual_end), we can potentially start the next chunk
-            # right after the split point to avoid redundant processing of the overlap zone
-            # if the split was significantly before the ideal 'end'.
-            # However, the simple `next_start = start + (chunk_size - overlap)` is safer
-            # to ensure consistent overlap handling unless more complex logic is added.
-            # Let's stick to the simpler approach for now:
-            # next_start = actual_end - overlap # This could lead to variable overlap size
-
-            # Ensure we always make progress
+            # Calculate next start position, ensuring we make progress
+            next_start = start + (actual_end - start - overlap)
             if next_start <= start:
-                logger.warning("Chunking logic resulted in no progress. Moving start by 1.")
                 next_start = start + 1
 
             start = next_start
 
-
         return chunks
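The deduplication added to `process_large_text` swaps the old exact `(start, end, entity_type)` key for a fuzzy overlap test. A self-contained sketch of the same rule, run on hand-made toy entities rather than real pipeline output:

# Toy demonstration of the >50%-overlap deduplication rule added above.
# The entity dicts are fabricated examples, not real pipeline output.
def dedupe(entities):
    unique = []
    for entity in entities:
        is_duplicate = False
        entity_length = entity['end'] - entity['start']
        for existing in unique:
            if existing['entity_type'] == entity['entity_type']:
                overlap_start = max(entity['start'], existing['start'])
                overlap_end = min(entity['end'], existing['end'])
                if overlap_start < overlap_end:
                    overlap_length = overlap_end - overlap_start
                    shorter = min(entity_length, existing['end'] - existing['start'])
                    if overlap_length > 0.5 * shorter:
                        is_duplicate = True
                        if entity['score'] > existing['score']:
                            unique.remove(existing)
                            is_duplicate = False
                        break
        if not is_duplicate:
            unique.append(entity)
    return unique

ents = [
    {'entity_type': 'PER', 'start': 0, 'end': 13, 'score': 0.99, 'word': 'Angela Merkel'},
    {'entity_type': 'PER', 'start': 7, 'end': 13, 'score': 0.85, 'word': 'Merkel'},  # chunk-overlap duplicate
    {'entity_type': 'LOC', 'start': 30, 'end': 36, 'score': 0.98, 'word': 'Berlin'},
]
print(dedupe(ents))  # keeps 'Angela Merkel' (higher score) and 'Berlin'

One consequence of the committed logic worth noting: when a higher-scoring entity evicts an existing one, the inner loop breaks without re-checking the newcomer against the other kept entities, so a chain of mutually overlapping same-type entities can all survive.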
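For a quick manual check of the refactored `_create_chunks` (assuming the repo root is the working directory and `transformers`/`torch` are installed, since importing ner_module pulls them in; no model download is needed for this static method):

# Prints (start_offset, chunk) pairs; chunks end on whitespace where possible,
# and consecutive chunks overlap by roughly `overlap` characters.
from ner_module import TextProcessor

sample = "Angela Merkel met Emmanuel Macron in Berlin. " * 10
for chunk, start in TextProcessor._create_chunks(sample, chunk_size=100, overlap=20):
    print(start, repr(chunk))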