AnseMin committed
Commit b61bec6 · 1 Parent(s): 6d39d2f

complete reimplementation of got ocr

Files changed (1)
  1. src/parsers/got_ocr_parser.py +132 -464
src/parsers/got_ocr_parser.py CHANGED
@@ -1,27 +1,13 @@
  from pathlib import Path
  from typing import Dict, List, Optional, Any, Union
- import json
- import os
- import tempfile
  import logging
+ import os
  import sys
- import importlib

- # Set PyTorch environment variables to force float16 instead of bfloat16
- os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX" # For T4 GPU compatibility
+ # Set PyTorch environment variables for T4 compatibility
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX"
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
- os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
- os.environ["PYTORCH_DISPATCHER_DISABLE_TORCH_FUNCTION_AUTOGRAD_FALLBACK"] = "1" # Disable fallbacks that might use bfloat16
-
- # Add patch for bfloat16 at the module level
- if 'torch' in sys.modules:
-     torch_module = sys.modules['torch']
-     if hasattr(torch_module, 'bfloat16'):
-         # Create a reference to the original bfloat16 function
-         original_bfloat16 = torch_module.bfloat16
-         # Replace it with float16
-         torch_module.bfloat16 = torch_module.float16
-         logger.info("Patched torch.bfloat16 to use torch.float16 instead")
+ os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"

  from src.parsers.parser_interface import DocumentParser
  from src.parsers.parser_registry import ParserRegistry
@@ -29,46 +15,10 @@ from src.parsers.parser_registry import ParserRegistry
  # Configure logging
  logger = logging.getLogger(__name__)

- # Global flag for NumPy availability
- NUMPY_AVAILABLE = False
- NUMPY_VERSION = None
-
- # Initialize torch as None in global scope to prevent reference errors
- torch = None
- GOT_AVAILABLE = False
-
- # Try to import NumPy
- try:
-     import numpy as np
-     NUMPY_AVAILABLE = True
-     NUMPY_VERSION = np.__version__
-     logger.info(f"NumPy version {NUMPY_VERSION} is available")
- except ImportError:
-     NUMPY_AVAILABLE = False
-     logger.error("NumPy is not available. This is required for GOT-OCR.")
-
- # Check if required packages are installed
- try:
-     import torch as torch_module
-     torch = torch_module # Assign to global variable
-     import transformers
-     from transformers import AutoModel, AutoTokenizer
-
-     # Check if transformers version is compatible
-     from packaging import version
-     if version.parse(transformers.__version__) >= version.parse("4.48.0"):
-         logger.warning(
-             f"Transformers version {transformers.__version__} may not be compatible with GOT-OCR. "
-             "Consider downgrading to version <4.48.0"
-         )
-
-     GOT_AVAILABLE = True and NUMPY_AVAILABLE
- except ImportError as e:
-     GOT_AVAILABLE = False
-     logger.warning(f"GOT-OCR dependencies not installed: {str(e)}. The parser will not be available.")
-
  class GotOcrParser(DocumentParser):
-     """Parser implementation using GOT-OCR 2.0."""
+     """Parser implementation using GOT-OCR 2.0 for document text extraction.
+     Optimized for NVIDIA T4 GPUs with explicit float16 support.
+     """

      _model = None
      _tokenizer = None
@@ -97,486 +47,204 @@
          return "GOT-OCR 2.0 parser for converting images to text (requires CUDA)"

      @classmethod
-     def _load_model(cls):
-         """Load the GOT-OCR model and tokenizer if not already loaded."""
-         global NUMPY_AVAILABLE, torch
-
-         if not NUMPY_AVAILABLE:
-             raise ImportError("NumPy is not available. This is required for GOT-OCR.")
+     def _check_dependencies(cls) -> bool:
+         """Check if all required dependencies are installed."""
+         try:
+             import numpy
+             import torch
+             import transformers
+             import tiktoken

-         if torch is None:
-             raise ImportError("PyTorch is not available. This is required for GOT-OCR.")
+             # Check CUDA availability if using torch
+             if hasattr(torch, 'cuda') and not torch.cuda.is_available():
+                 logger.warning("CUDA is not available. GOT-OCR performs best with GPU acceleration.")

+             return True
+         except ImportError as e:
+             logger.error(f"Missing dependency: {e}")
+             return False
+
+     @classmethod
+     def _load_model(cls):
+         """Load the GOT-OCR model and tokenizer if not already loaded."""
          if cls._model is None or cls._tokenizer is None:
              try:
+                 # Import dependencies inside the method to avoid global import errors
+                 import torch
+                 from transformers import AutoModel, AutoTokenizer
+
                  logger.info("Loading GOT-OCR model and tokenizer...")
+
+                 # Load tokenizer
                  cls._tokenizer = AutoTokenizer.from_pretrained(
-                     'stepfun-ai/GOT-OCR2_0',
+                     'stepfun-ai/GOT-OCR2_0',
                      trust_remote_code=True
                  )

-                 # Determine device mapping based on CUDA availability
-                 if torch.cuda.is_available():
-                     logger.info("Using CUDA device for model loading")
-                     device_map = 'cuda'
+                 # Determine device
+                 device_map = 'cuda' if torch.cuda.is_available() else 'auto'
+                 if device_map == 'cuda':
+                     logger.info("Using CUDA for model inference")
                  else:
-                     logger.warning("No GPU available, falling back to CPU (not recommended)")
-                     device_map = 'auto'
-
-                 # Set torch default dtype to float16 since the CUDA device doesn't support bfloat16
-                 logger.info("Setting default tensor type to float16")
-                 torch.set_default_tensor_type(torch.FloatTensor)
-                 torch.set_default_dtype(torch.float16)
-
-                 # Aggressively patch torch.autocast to always use float16
-                 original_autocast = torch.amp.autocast if hasattr(torch.amp, 'autocast') else None
-
-                 if original_autocast:
-                     def patched_autocast(*args, **kwargs):
-                         # Force dtype to float16
-                         kwargs['dtype'] = torch.float16
-                         return original_autocast(*args, **kwargs)
-
-                     torch.amp.autocast = patched_autocast
-                     logger.info("Patched torch.amp.autocast to always use float16")
-
-                 # Patch tensor casting methods for bfloat16
-                 if hasattr(torch, 'Tensor'):
-                     if hasattr(torch.Tensor, 'to'):
-                         original_to = torch.Tensor.to
-                         def patched_to(self, *args, **kwargs):
-                             # If the first arg is a dtype and it's bfloat16, replace with float16
-                             if args and args[0] == torch.bfloat16:
-                                 logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
-                                 args = list(args)
-                                 args[0] = torch.float16
-                                 args = tuple(args)
-                             # If dtype is specified in kwargs and it's bfloat16, replace with float16
-                             if kwargs.get('dtype') == torch.bfloat16:
-                                 logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
-                                 kwargs['dtype'] = torch.float16
-                             return original_to(self, *args, **kwargs)
-
-                         torch.Tensor.to = patched_to
-                         logger.info("Patched torch.Tensor.to method to prevent bfloat16 usage")
+                     logger.warning("Using CPU for model inference (not recommended)")

-                 # Load the model with explicit float16 dtype
+                 # Load model with explicit float16 for T4 compatibility
                  cls._model = AutoModel.from_pretrained(
-                     'stepfun-ai/GOT-OCR2_0',
-                     trust_remote_code=True,
-                     low_cpu_mem_usage=True,
-                     device_map=device_map,
+                     'stepfun-ai/GOT-OCR2_0',
+                     trust_remote_code=True,
+                     low_cpu_mem_usage=True,
+                     device_map=device_map,
                      use_safetensors=True,
-                     pad_token_id=cls._tokenizer.eos_token_id,
-                     torch_dtype=torch.float16 # Explicitly specify float16 dtype
+                     torch_dtype=torch.float16, # Force float16 for T4 compatibility
+                     pad_token_id=cls._tokenizer.eos_token_id
                  )

-                 # Ensure all model parameters are float16
-                 for param in cls._model.parameters():
-                     param.data = param.data.to(torch.float16)
-
-                 # Examine model internals to find any direct bfloat16 usage
-                 def find_and_patch_bfloat16_attributes(module, path=""):
-                     for name, child in module.named_children():
-                         child_path = f"{path}.{name}" if path else name
-                         # Check if any attribute contains "bfloat16" in its name
-                         for attr_name in dir(child):
-                             if "bfloat16" in attr_name.lower():
-                                 try:
-                                     # Try to get the attribute
-                                     attr_value = getattr(child, attr_name)
-                                     logger.warning(f"Found potential bfloat16 usage at {child_path}.{attr_name}")
-                                     # Try to replace with float16 equivalent if it exists
-                                     float16_attr_name = attr_name.replace("bfloat16", "float16").replace("bf16", "fp16")
-                                     if hasattr(child, float16_attr_name):
-                                         logger.info(f"Replacing {attr_name} with {float16_attr_name}")
-                                         setattr(child, attr_name, getattr(child, float16_attr_name))
-                                 except Exception as e:
-                                     logger.error(f"Error examining attribute {attr_name}: {e}")
-                         # Recursively check child modules
-                         find_and_patch_bfloat16_attributes(child, child_path)
-
-                 # Apply the internal examination
-                 logger.info("Examining model for potential bfloat16 usage...")
-                 find_and_patch_bfloat16_attributes(cls._model)
+                 # Explicitly convert model to half precision (float16)
+                 cls._model = cls._model.half().eval()

-                 # Patch the model's chat method to use float16 instead of bfloat16
-                 logger.info("Patching model to use float16 instead of bfloat16")
-                 original_chat = cls._model.chat
-
-                 # Get the original signature to understand the proper parameter order
-                 import inspect
-                 try:
-                     original_sig = inspect.signature(original_chat)
-                     logger.info(f"Original chat method signature: {original_sig}")
-                 except Exception as e:
-                     logger.warning(f"Could not inspect original chat method: {e}")
-
-                 # Define a completely new patched chat method that avoids parameter conflicts
-                 def patched_chat(self, tokenizer, image_path, ocr_type, **kwargs):
-                     """A patched version of chat method that forces float16 precision"""
-                     logger.info(f"Using patched chat method with float16, ocr_type={ocr_type}")
-
-                     # Force any bfloat16 tensors to float16
-                     if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
-                         torch.bfloat16 = torch.float16
-                         logger.info("Forcing torch.bfloat16 to be torch.float16 within chat method")
-
-                     # Set explicit autocast dtype
-                     if hasattr(torch.amp, 'autocast'):
-                         with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                             try:
-                                 # Pass arguments correctly - without 'self' as first arg since original_chat is already bound
-                                 return original_chat(tokenizer, image_path, ocr_type, **kwargs)
-                             except TypeError as e:
-                                 logger.warning(f"First call approach failed: {e}, trying alternative approach")
-                                 try:
-                                     # Try passing image_path as string in case that's the issue
-                                     return original_chat(tokenizer, str(image_path), ocr_type, **kwargs)
-                                 except Exception as e2:
-                                     logger.warning(f"Second call approach also failed: {e2}")
-                                     # Fall back to original method with keyword args
-                                     return original_chat(tokenizer, image_path, ocr_type=ocr_type, **kwargs)
-                             except RuntimeError as e:
-                                 if "bfloat16" in str(e):
-                                     logger.error(f"BFloat16 error encountered despite patching: {e}")
-                                     # More aggressive handling
-                                     if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
-                                         logger.info("Attempting with torch.cuda.amp.autocast as last resort")
-                                         with torch.cuda.amp.autocast(dtype=torch.float16):
-                                             return original_chat(tokenizer, str(image_path), ocr_type, **kwargs)
-                                     raise RuntimeError(f"GPU doesn't support bfloat16: {e}")
-                                 else:
-                                     raise
-                     else:
-                         # If autocast is not available, try to manually ensure everything is float16
-                         try:
-                             # Direct call without 'self' as first arg
-                             return original_chat(tokenizer, image_path, ocr_type, **kwargs)
-                         except TypeError as e:
-                             logger.warning(f"Call without autocast failed: {e}, trying alternative approach")
-                             try:
-                                 # Try passing image_path as string in case that's the issue
-                                 return original_chat(tokenizer, str(image_path), ocr_type, **kwargs)
-                             except:
-                                 # Fall back to keyword arguments
-                                 return original_chat(tokenizer, image_path, ocr_type=ocr_type, **kwargs)
-
-                 # Apply the patch
-                 import types
-                 cls._model.chat = types.MethodType(patched_chat, cls._model)
-
-                 # Check if the model has a cast_to_bfloat16 method and override it
-                 if hasattr(cls._model, 'cast_to_bfloat16'):
-                     original_cast = cls._model.cast_to_bfloat16
-                     def patched_cast(self, *args, **kwargs):
-                         logger.info("Intercepted attempt to cast model to bfloat16, using float16 instead")
-                         # If the model has a cast_to_float16 method, use that instead
-                         if hasattr(self, 'cast_to_float16'):
-                             return self.cast_to_float16(*args, **kwargs)
-                         # Otherwise, cast all parameters manually
-                         for param in self.parameters():
-                             param.data = param.data.to(torch.float16)
-                         return self
-
-                     cls._model.cast_to_bfloat16 = types.MethodType(patched_cast, cls._model)
-                     logger.info("Patched model.cast_to_bfloat16 method")
-
-                 # Set model to evaluation mode
+                 # Move to CUDA if available
                  if device_map == 'cuda':
-                     cls._model = cls._model.eval().cuda().half() # Explicitly cast to half precision (float16)
-                 else:
-                     cls._model = cls._model.eval().half() # Explicitly cast to half precision (float16)
+                     cls._model = cls._model.cuda()

-                 # Reset default dtype to float32 after model loading
-                 torch.set_default_dtype(torch.float32)
-                 torch.set_default_tensor_type(torch.FloatTensor)
-
                  logger.info("GOT-OCR model loaded successfully")
+                 return True
              except Exception as e:
                  cls._model = None
                  cls._tokenizer = None
                  logger.error(f"Failed to load GOT-OCR model: {str(e)}")
-                 raise RuntimeError(f"Failed to load GOT-OCR model: {str(e)}")
+                 return False
+         return True

      @classmethod
      def release_model(cls):
          """Release the model from memory."""
-         global torch
-
-         if cls._model is not None:
-             del cls._model
-             cls._model = None
-         if cls._tokenizer is not None:
-             del cls._tokenizer
-             cls._tokenizer = None
-         if torch is not None and hasattr(torch, 'cuda') and hasattr(torch.cuda, 'empty_cache'):
-             torch.cuda.empty_cache()
-
-         logger.info("GOT-OCR model released from memory")
-
-     def _try_install_numpy(self):
-         """Attempt to install NumPy using pip."""
-         global NUMPY_AVAILABLE, NUMPY_VERSION
-
-         logger.warning("Attempting to install NumPy...")
          try:
-             import subprocess
-             # Try to install numpy with explicit version constraint for compatibility with torchvision
-             result = subprocess.run(
-                 [sys.executable, "-m", "pip", "install", "-q", "numpy<2.0.0", "--no-cache-dir"],
-                 capture_output=True,
-                 text=True,
-                 check=True
-             )
-             logger.info(f"NumPy installation result: {result.stdout}")
+             import torch

-             # Try to import numpy again
-             importlib.invalidate_caches()
-             import numpy as np
-             importlib.reload(np)
+             if cls._model is not None:
+                 del cls._model
+                 cls._model = None

-             NUMPY_AVAILABLE = True
-             NUMPY_VERSION = np.__version__
-             logger.info(f"NumPy installed successfully: version {NUMPY_VERSION}")
-             return True
-         except Exception as e:
-             logger.error(f"Failed to install NumPy: {str(e)}")
-             if hasattr(e, 'stderr'):
-                 logger.error(f"Installation error output: {e.stderr}")
-             return False
-
-     def _try_install_torch(self):
-         """Attempt to install PyTorch using pip."""
-         global torch
-
-         logger.warning("Attempting to install PyTorch...")
-         try:
-             import subprocess
-             # Install PyTorch with version constraint as per the requirements
-             result = subprocess.run(
-                 [sys.executable, "-m", "pip", "install", "-q", "torch==2.0.1", "torchvision==0.15.2", "--no-cache-dir"],
-                 capture_output=True,
-                 text=True,
-                 check=True
-             )
-             logger.info(f"PyTorch installation result: {result.stdout}")
+             if cls._tokenizer is not None:
+                 del cls._tokenizer
+                 cls._tokenizer = None

-             # Try to import torch again
-             importlib.invalidate_caches()
-             import torch as torch_module
-             torch = torch_module
+             # Clear CUDA cache if available
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()

-             logger.info(f"PyTorch installed successfully: version {torch.__version__}")
-             return True
+             logger.info("GOT-OCR model released from memory")
          except Exception as e:
-             logger.error(f"Failed to install PyTorch: {str(e)}")
-             if hasattr(e, 'stderr'):
-                 logger.error(f"Installation error output: {e.stderr}")
-             return False
+             logger.error(f"Error releasing model: {str(e)}")

      def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
-         """Parse a document using GOT-OCR 2.0."""
-         global NUMPY_AVAILABLE, GOT_AVAILABLE, torch
-
-         # Check NumPy availability and try to install if not available
-         if not NUMPY_AVAILABLE:
-             logger.warning("NumPy not available, attempting to install it...")
-             if self._try_install_numpy():
-                 # NumPy is now available
-                 logger.info("NumPy is now available")
-             else:
-                 logger.error("Failed to install NumPy. Cannot proceed with GOT-OCR.")
-                 raise ImportError(
-                     "NumPy is not available and could not be installed automatically. "
-                     "Please ensure NumPy is installed in your environment. "
-                     "Add the following to your logs for debugging: NUMPY_INSTALLATION_FAILED"
-                 )
+         """Parse a document using GOT-OCR 2.0.
+
+         Args:
+             file_path: Path to the image file
+             ocr_method: OCR method to use ('plain' or 'format')
+             **kwargs: Additional arguments to pass to the model
+
+         Returns:
+             Extracted text from the image
+         """
+         # Verify dependencies are installed
+         if not self._check_dependencies():
+             raise ImportError(
+                 "Required dependencies are missing. Please install: "
+                 "torch==2.0.1 torchvision==0.15.2 transformers==4.37.2 "
+                 "tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0"
+             )

-         # Check PyTorch availability and try to install if not available
-         if torch is None:
-             logger.warning("PyTorch not available, attempting to install it...")
-             if self._try_install_torch():
-                 # PyTorch is now available
-                 logger.info("PyTorch is now available")
-             else:
-                 logger.error("Failed to install PyTorch. Cannot proceed with GOT-OCR.")
-                 raise ImportError(
-                     "PyTorch is not available and could not be installed automatically. "
-                     "Please ensure PyTorch is installed in your environment."
-                 )
-
-         # Update GOT availability flag after potential installations
-         try:
-             if NUMPY_AVAILABLE and torch is not None:
-                 import transformers
-                 GOT_AVAILABLE = True
-                 logger.info("Updated GOT availability after installations: Available")
-             else:
-                 GOT_AVAILABLE = False
-                 logger.error("GOT availability after installations: Not Available (missing dependencies)")
-         except ImportError:
-             GOT_AVAILABLE = False
-             logger.error("Transformers not available. GOT-OCR cannot be used.")
+         # Load model if not already loaded
+         if not self._load_model():
+             raise RuntimeError("Failed to load GOT-OCR model")

-         # Check overall GOT availability
-         if not GOT_AVAILABLE:
-             if not NUMPY_AVAILABLE:
-                 logger.error("NumPy is still not available after installation attempt.")
-                 raise ImportError(
-                     "NumPy is not available. This is required for GOT-OCR. "
-                     "Please ensure NumPy is installed in your environment. "
-                     "Environment details: Python " + sys.version
-                 )
-             elif torch is None:
-                 logger.error("PyTorch is still not available after installation attempt.")
-                 raise ImportError(
-                     "PyTorch is not available. This is required for GOT-OCR. "
-                     "Please ensure PyTorch is installed in your environment."
-                 )
-             else:
-                 logger.error("Other GOT-OCR dependencies missing even though NumPy and PyTorch are available.")
-                 raise ImportError(
-                     "GOT-OCR dependencies not installed. Please install required packages: "
-                     "transformers, tiktoken, verovio, accelerate"
-                 )
+         # Import torch here to ensure it's available
+         import torch

-         # Check if CUDA is available
-         cuda_available = torch is not None and hasattr(torch, 'cuda') and hasattr(torch.cuda, 'is_available') and torch.cuda.is_available()
-         if not cuda_available:
-             logger.warning("No GPU available. GOT-OCR performance may be severely degraded.")
-
-         # Check file extension
+         # Validate file path and extension
          file_path = Path(file_path)
+         if not file_path.exists():
+             raise FileNotFoundError(f"Image file not found: {file_path}")
+
          if file_path.suffix.lower() not in ['.jpg', '.jpeg', '.png']:
              raise ValueError(
-                 "GOT-OCR only supports JPG and PNG formats. "
+                 f"GOT-OCR only supports JPG and PNG formats. "
                  f"Received file with extension: {file_path.suffix}"
              )

          # Determine OCR type based on method
          ocr_type = "format" if ocr_method == "format" else "ocr"
+         logger.info(f"Using OCR method: {ocr_type}")

+         # Process the image
          try:
-             # Check if numpy needs to be reloaded
-             if 'numpy' in sys.modules:
-                 logger.info("NumPy module found in sys.modules, attempting to reload...")
-                 try:
-                     importlib.reload(sys.modules['numpy'])
-                     import numpy as np
-                     logger.info(f"NumPy reloaded successfully: version {np.__version__}")
-                 except Exception as e:
-                     logger.error(f"Error reloading NumPy: {str(e)}")
-
-             # Load the model
-             self._load_model()
-
-             # Use the model's chat method as shown in the documentation
              logger.info(f"Processing image with GOT-OCR: {file_path}")
+
+             # First attempt: Normal processing with autocast
              try:
-                 # Use ocr_type as a positional argument based on the correct signature
-                 logger.info(f"Using OCR method: {ocr_type}")
-
-                 # Temporarily force any PyTorch operations to use float16
                  with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                      result = self._model.chat(
-                         self._tokenizer,
-                         str(file_path),
-                         ocr_type # Pass as positional arg, not keyword
+                         self._tokenizer,
+                         str(file_path),
+                         ocr_type
                      )
+                 return result
              except RuntimeError as e:
+                 # Check if it's a bfloat16 error
                  if "bfloat16" in str(e) or "BFloat16" in str(e):
-                     logger.warning("Caught bfloat16 error, trying to force float16 with autocast")
-                     # Try with explicit autocast
+                     logger.warning("Encountered bfloat16 error, trying float16 fallback")
+
+                     # Second attempt: More aggressive float16 forcing
                      try:
-                         # More aggressive approach with multiple settings
+                         # Ensure model is float16
+                         self._model = self._model.half()

-                         # Ensure bfloat16 is aliased to float16 globally
-                         if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
-                             logger.info("Forcing bfloat16 to be float16 in exception handler")
-                             torch.bfloat16 = torch.float16
+                         # Set default dtype temporarily
+                         original_dtype = torch.get_default_dtype()
+                         torch.set_default_dtype(torch.float16)

-                         # Apply patch to the model's config if it exists
-                         if hasattr(self._model, 'config'):
-                             if hasattr(self._model.config, 'torch_dtype'):
-                                 logger.info(f"Setting model config dtype from {self._model.config.torch_dtype} to float16")
-                                 self._model.config.torch_dtype = torch.float16
-
-                         # Try with all possible autocast combinations
                          with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                             # Temporarily set default dtype
-                             old_dtype = torch.get_default_dtype()
-                             torch.set_default_dtype(torch.float16)
-
-                             # Call with positional argument for ocr_type
-                             logger.info("Using fallback with autocast and default dtype set to float16")
                              result = self._model.chat(
                                  self._tokenizer,
                                  str(file_path),
                                  ocr_type
                              )

-                             # Restore default dtype
-                             torch.set_default_dtype(old_dtype)
-                     except Exception as inner_e:
-                         logger.error(f"Error in fallback method: {str(inner_e)}")
-
-                         # Last resort: try using torchscript if available
-                         try:
-                             logger.info("Attempting final approach with model.half() and direct call")
-                             # Force model to half precision
-                             self._model = self._model.half()
-
-                             # Try direct call with the original method
-                             result = self._model.chat(
-                                 self._tokenizer,
-                                 str(file_path),
-                                 ocr_type
-                             )
-                         except Exception as final_e:
-                             logger.error(f"All fallback approaches failed: {str(final_e)}")
-                             raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(final_e)}")
+                         # Restore default dtype
+                         torch.set_default_dtype(original_dtype)
+                         return result
+                     except Exception as inner_e:
+                         logger.error(f"Float16 fallback failed: {str(inner_e)}")
+                         raise RuntimeError(
+                             f"Failed to process image with GOT-OCR: {str(inner_e)}"
+                         )
                  else:
-                     # Re-raise other errors
+                     # Not a bfloat16 error, re-raise
                      raise
-
-             # Return the result directly as markdown
-             return result
-
+
          except Exception as e:
-             error_type = type(e).__name__
+             logger.error(f"Error processing image with GOT-OCR: {str(e)}")

-             # Handle specific error types
-             if torch is not None and hasattr(torch, 'cuda') and error_type == 'OutOfMemoryError':
-                 self.release_model() # Release memory
-                 logger.error("GPU out of memory while processing with GOT-OCR")
+             # Handle specific errors with helpful messages
+             error_type = type(e).__name__
+             if error_type == 'OutOfMemoryError':
+                 self.release_model()
                  raise RuntimeError(
                      "GPU out of memory while processing with GOT-OCR. "
                      "Try using a smaller image or a different parser."
                  )
-             elif error_type == 'AttributeError' and "get_max_length" in str(e):
-                 logger.error(f"Transformers version compatibility error: {str(e)}")
-                 self.release_model() # Release memory
-                 raise RuntimeError(
-                     "Transformers version compatibility error with GOT-OCR. "
-                     "Please downgrade transformers to version <4.48.0. "
-                     f"Error: {str(e)}"
-                 )
-             else:
-                 logger.error(f"Error processing document with GOT-OCR: {str(e)}")
-                 raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
+
+             # Generic error
+             raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")

- # Register the parser with the registry if dependencies are available
+ # Try to register the parser
  try:
-     if NUMPY_AVAILABLE and torch is not None:
-         ParserRegistry.register(GotOcrParser)
-         logger.info("GOT-OCR parser registered successfully")
-     else:
-         missing_deps = []
-         if not NUMPY_AVAILABLE:
-             missing_deps.append("NumPy")
-         if torch is None:
-             missing_deps.append("PyTorch")
-         logger.warning(f"GOT-OCR parser not registered: missing dependencies: {', '.join(missing_deps)}")
- except Exception as e:
-     logger.error(f"Error registering GOT-OCR parser: {str(e)}")
+     # Only check basic imports, detailed dependency check happens in parse method
+     import numpy
+     import torch
+     ParserRegistry.register(GotOcrParser)
+     logger.info("GOT-OCR parser registered successfully")
+ except ImportError as e:
+     logger.warning(f"Could not register GOT-OCR parser: {str(e)}")