AnseMin commited on
Commit
d95f73e
·
1 Parent(s): 84a7af0

Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.

Browse files
Files changed (1) hide show
  1. src/parsers/got_ocr_parser.py +63 -5
src/parsers/got_ocr_parser.py CHANGED
@@ -7,6 +7,11 @@ import logging
7
  import sys
8
  import importlib
9
 
 
 
 
 
 
10
  from src.parsers.parser_interface import DocumentParser
11
  from src.parsers.parser_registry import ParserRegistry
12
 
@@ -122,6 +127,32 @@ class GotOcrParser(DocumentParser):
122
  torch_dtype=torch.float16 # Explicitly specify float16 dtype
123
  )
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  # Set model to evaluation mode
126
  if device_map == 'cuda':
127
  cls._model = cls._model.eval().cuda()
@@ -313,11 +344,38 @@ class GotOcrParser(DocumentParser):
313
 
314
  # Use the model's chat method as shown in the documentation
315
  logger.info(f"Processing image with GOT-OCR: {file_path}")
316
- result = self._model.chat(
317
- self._tokenizer,
318
- str(file_path),
319
- ocr_type=ocr_type
320
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  # Return the result directly as markdown
323
  return result
 
7
  import sys
8
  import importlib
9
 
10
+ # Set PyTorch environment variables to force float16 instead of bfloat16
11
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX" # For T4 GPU compatibility
12
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
13
+ os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
14
+
15
  from src.parsers.parser_interface import DocumentParser
16
  from src.parsers.parser_registry import ParserRegistry
17
 
 
127
  torch_dtype=torch.float16 # Explicitly specify float16 dtype
128
  )
129
 
130
+ # Patch the model's chat method to use float16 instead of bfloat16
131
+ logger.info("Patching model to use float16 instead of bfloat16")
132
+ original_chat = cls._model.chat
133
+
134
def patched_chat(self, tokenizer, image_path, *args, **kwargs):
    """Run the model's original ``chat`` under a float16 autocast context.

    Bound onto the model via ``types.MethodType`` so that GPUs without
    bfloat16 support (e.g. T4) can still run GOT-OCR inference.

    Args:
        tokenizer: Tokenizer forwarded unchanged to the original chat method.
        image_path: Path of the image to OCR, forwarded unchanged.
        *args: Extra positional arguments forwarded verbatim.
        **kwargs: Extra keyword arguments forwarded verbatim (e.g. ``ocr_type``).

    Returns:
        Whatever the wrapped ``chat`` method returns (the OCR result).

    Raises:
        RuntimeError: Re-raised with a clearer message when the underlying
            call still fails because the device lacks bfloat16 support;
            any other RuntimeError propagates unchanged.
    """
    # Confirm at runtime that the patched path is actually in use.
    logger.info("Using patched chat method with float16")

    # NOTE: `original_chat` was captured as a *bound* method
    # (`cls._model.chat`), so the model must NOT be passed again as the
    # first positional argument — doing so shifts every argument by one
    # slot (the model would land in the `tokenizer` parameter). Forward
    # only the real arguments.
    if hasattr(torch.amp, 'autocast'):
        # Force float16 autocast so ops that would default to bfloat16
        # execute in float16 on devices that only support float16.
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            try:
                return original_chat(tokenizer, image_path, *args, **kwargs)
            except RuntimeError as e:
                if "bfloat16" in str(e):
                    logger.error(f"BFloat16 error encountered despite patching: {e}")
                    raise RuntimeError(f"GPU doesn't support bfloat16: {e}") from e
                # Unrelated runtime failure: propagate untouched.
                raise
    else:
        # Older torch without torch.amp.autocast: call through unchanged.
        return original_chat(tokenizer, image_path, *args, **kwargs)
151
+
152
+ # Apply the patch
153
+ import types
154
+ cls._model.chat = types.MethodType(patched_chat, cls._model)
155
+
156
  # Set model to evaluation mode
157
  if device_map == 'cuda':
158
  cls._model = cls._model.eval().cuda()
 
344
 
345
  # Use the model's chat method as shown in the documentation
346
  logger.info(f"Processing image with GOT-OCR: {file_path}")
347
+ try:
348
+ # First try with patched method
349
+ result = self._model.chat(
350
+ self._tokenizer,
351
+ str(file_path),
352
+ ocr_type=ocr_type
353
+ )
354
+ except RuntimeError as e:
355
+ if "bfloat16" in str(e) or "BFloat16" in str(e):
356
+ logger.warning("Caught bfloat16 error, trying to force float16 with autocast")
357
+ # Try with explicit autocast
358
+ try:
359
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
360
+ # Temporarily set default dtype
361
+ old_dtype = torch.get_default_dtype()
362
+ torch.set_default_dtype(torch.float16)
363
+
364
+ # Call the original method directly
365
+ result = self._model.chat(
366
+ self._tokenizer,
367
+ str(file_path),
368
+ ocr_type=ocr_type
369
+ )
370
+
371
+ # Restore default dtype
372
+ torch.set_default_dtype(old_dtype)
373
+ except Exception as inner_e:
374
+ logger.error(f"Error in fallback method: {str(inner_e)}")
375
+ raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(inner_e)}")
376
+ else:
377
+ # Re-raise other errors
378
+ raise
379
 
380
  # Return the result directly as markdown
381
  return result