AnseMin committed on
Commit 6d39d2f · 1 Parent(s): e5648b4

Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
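The root cause is that pre-Ampere GPUs such as the T4 this Space targets (see the TORCH_CUDA_ARCH_LIST setting in the diff) have no native bfloat16 support, so the commit forces float16 throughout instead. A minimal sketch of the capability check involved, using PyTorch's torch.cuda.is_bf16_supported(); the helper name is illustrative and not part of this repo:

import torch

def pick_supported_dtype() -> torch.dtype:
    # Illustrative helper: prefer bfloat16 only when the current CUDA device
    # supports it natively, otherwise fall back to float16 (float32 on CPU).
    if not torch.cuda.is_available():
        return torch.float32
    if torch.cuda.is_bf16_supported():  # False on pre-Ampere GPUs such as the T4
        return torch.bfloat16
    return torch.float16

# e.g. AutoModel.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True,
#                                torch_dtype=pick_supported_dtype())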

Files changed (1)
  1. src/parsers/got_ocr_parser.py +141 -9
src/parsers/got_ocr_parser.py CHANGED
@@ -11,6 +11,17 @@ import importlib
 os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX"  # For T4 GPU compatibility
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
 os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
+os.environ["PYTORCH_DISPATCHER_DISABLE_TORCH_FUNCTION_AUTOGRAD_FALLBACK"] = "1"  # Disable fallbacks that might use bfloat16
+
+# Add patch for bfloat16 at the module level
+if 'torch' in sys.modules:
+    torch_module = sys.modules['torch']
+    if hasattr(torch_module, 'bfloat16'):
+        # Create a reference to the original bfloat16 function
+        original_bfloat16 = torch_module.bfloat16
+        # Replace it with float16
+        torch_module.bfloat16 = torch_module.float16
+        logger.info("Patched torch.bfloat16 to use torch.float16 instead")
 
 from src.parsers.parser_interface import DocumentParser
 from src.parsers.parser_registry import ParserRegistry
@@ -117,6 +128,39 @@ class GotOcrParser(DocumentParser):
         torch.set_default_tensor_type(torch.FloatTensor)
         torch.set_default_dtype(torch.float16)
 
+        # Aggressively patch torch.autocast to always use float16
+        original_autocast = torch.amp.autocast if hasattr(torch.amp, 'autocast') else None
+
+        if original_autocast:
+            def patched_autocast(*args, **kwargs):
+                # Force dtype to float16
+                kwargs['dtype'] = torch.float16
+                return original_autocast(*args, **kwargs)
+
+            torch.amp.autocast = patched_autocast
+            logger.info("Patched torch.amp.autocast to always use float16")
+
+        # Patch tensor casting methods for bfloat16
+        if hasattr(torch, 'Tensor'):
+            if hasattr(torch.Tensor, 'to'):
+                original_to = torch.Tensor.to
+                def patched_to(self, *args, **kwargs):
+                    # If the first arg is a dtype and it's bfloat16, replace with float16
+                    if args and args[0] == torch.bfloat16:
+                        logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
+                        args = list(args)
+                        args[0] = torch.float16
+                        args = tuple(args)
+                    # If dtype is specified in kwargs and it's bfloat16, replace with float16
+                    if kwargs.get('dtype') == torch.bfloat16:
+                        logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
+                        kwargs['dtype'] = torch.float16
+                    return original_to(self, *args, **kwargs)
+
+                torch.Tensor.to = patched_to
+                logger.info("Patched torch.Tensor.to method to prevent bfloat16 usage")
+
+        # Load the model with explicit float16 dtype
         cls._model = AutoModel.from_pretrained(
             'stepfun-ai/GOT-OCR2_0',
             trust_remote_code=True,
@@ -127,6 +171,35 @@ class GotOcrParser(DocumentParser):
             torch_dtype=torch.float16  # Explicitly specify float16 dtype
         )
 
+        # Ensure all model parameters are float16
+        for param in cls._model.parameters():
+            param.data = param.data.to(torch.float16)
+
+        # Examine model internals to find any direct bfloat16 usage
+        def find_and_patch_bfloat16_attributes(module, path=""):
+            for name, child in module.named_children():
+                child_path = f"{path}.{name}" if path else name
+                # Check if any attribute contains "bfloat16" in its name
+                for attr_name in dir(child):
+                    if "bfloat16" in attr_name.lower():
+                        try:
+                            # Try to get the attribute
+                            attr_value = getattr(child, attr_name)
+                            logger.warning(f"Found potential bfloat16 usage at {child_path}.{attr_name}")
+                            # Try to replace with float16 equivalent if it exists
+                            float16_attr_name = attr_name.replace("bfloat16", "float16").replace("bf16", "fp16")
+                            if hasattr(child, float16_attr_name):
+                                logger.info(f"Replacing {attr_name} with {float16_attr_name}")
+                                setattr(child, attr_name, getattr(child, float16_attr_name))
+                        except Exception as e:
+                            logger.error(f"Error examining attribute {attr_name}: {e}")
+                # Recursively check child modules
+                find_and_patch_bfloat16_attributes(child, child_path)
+
+        # Apply the internal examination
+        logger.info("Examining model for potential bfloat16 usage...")
+        find_and_patch_bfloat16_attributes(cls._model)
+
         # Patch the model's chat method to use float16 instead of bfloat16
         logger.info("Patching model to use float16 instead of bfloat16")
         original_chat = cls._model.chat
@@ -144,6 +217,11 @@ class GotOcrParser(DocumentParser):
             """A patched version of chat method that forces float16 precision"""
             logger.info(f"Using patched chat method with float16, ocr_type={ocr_type}")
 
+            # Force any bfloat16 tensors to float16
+            if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
+                torch.bfloat16 = torch.float16
+                logger.info("Forcing torch.bfloat16 to be torch.float16 within chat method")
+
             # Set explicit autocast dtype
             if hasattr(torch.amp, 'autocast'):
                 with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
@@ -162,11 +240,16 @@
                     except RuntimeError as e:
                         if "bfloat16" in str(e):
                             logger.error(f"BFloat16 error encountered despite patching: {e}")
+                            # More aggressive handling
+                            if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
+                                logger.info("Attempting with torch.cuda.amp.autocast as last resort")
+                                with torch.cuda.amp.autocast(dtype=torch.float16):
+                                    return original_chat(tokenizer, str(image_path), ocr_type, **kwargs)
                             raise RuntimeError(f"GPU doesn't support bfloat16: {e}")
                         else:
                             raise
             else:
-                # Same approach without autocast
+                # If autocast is not available, try to manually ensure everything is float16
                 try:
                     # Direct call without 'self' as first arg
                     return original_chat(tokenizer, image_path, ocr_type, **kwargs)
@@ -183,11 +266,27 @@
         import types
         cls._model.chat = types.MethodType(patched_chat, cls._model)
 
+        # Check if the model has a cast_to_bfloat16 method and override it
+        if hasattr(cls._model, 'cast_to_bfloat16'):
+            original_cast = cls._model.cast_to_bfloat16
+            def patched_cast(self, *args, **kwargs):
+                logger.info("Intercepted attempt to cast model to bfloat16, using float16 instead")
+                # If the model has a cast_to_float16 method, use that instead
+                if hasattr(self, 'cast_to_float16'):
+                    return self.cast_to_float16(*args, **kwargs)
+                # Otherwise, cast all parameters manually
+                for param in self.parameters():
+                    param.data = param.data.to(torch.float16)
+                return self
+
+            cls._model.cast_to_bfloat16 = types.MethodType(patched_cast, cls._model)
+            logger.info("Patched model.cast_to_bfloat16 method")
+
         # Set model to evaluation mode
         if device_map == 'cuda':
-            cls._model = cls._model.eval().cuda()
+            cls._model = cls._model.eval().cuda().half()  # Explicitly cast to half precision (float16)
         else:
-            cls._model = cls._model.eval()
+            cls._model = cls._model.eval().half()  # Explicitly cast to half precision (float16)
 
         # Reset default dtype to float32 after model loading
         torch.set_default_dtype(torch.float32)
@@ -377,22 +476,40 @@
         try:
             # Use ocr_type as a positional argument based on the correct signature
             logger.info(f"Using OCR method: {ocr_type}")
-            result = self._model.chat(
-                self._tokenizer,
-                str(file_path),
-                ocr_type  # Pass as positional arg, not keyword
-            )
+
+            # Temporarily force any PyTorch operations to use float16
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                result = self._model.chat(
+                    self._tokenizer,
+                    str(file_path),
+                    ocr_type  # Pass as positional arg, not keyword
+                )
         except RuntimeError as e:
             if "bfloat16" in str(e) or "BFloat16" in str(e):
                 logger.warning("Caught bfloat16 error, trying to force float16 with autocast")
                 # Try with explicit autocast
                 try:
+                    # More aggressive approach with multiple settings
+
+                    # Ensure bfloat16 is aliased to float16 globally
+                    if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
+                        logger.info("Forcing bfloat16 to be float16 in exception handler")
+                        torch.bfloat16 = torch.float16
+
+                    # Apply patch to the model's config if it exists
+                    if hasattr(self._model, 'config'):
+                        if hasattr(self._model.config, 'torch_dtype'):
+                            logger.info(f"Setting model config dtype from {self._model.config.torch_dtype} to float16")
+                            self._model.config.torch_dtype = torch.float16
+
+                    # Try with all possible autocast combinations
                     with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                         # Temporarily set default dtype
                         old_dtype = torch.get_default_dtype()
                         torch.set_default_dtype(torch.float16)
 
                         # Call with positional argument for ocr_type
+                        logger.info("Using fallback with autocast and default dtype set to float16")
                         result = self._model.chat(
                             self._tokenizer,
                             str(file_path),
@@ -403,7 +520,22 @@
                         torch.set_default_dtype(old_dtype)
                 except Exception as inner_e:
                     logger.error(f"Error in fallback method: {str(inner_e)}")
-                    raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(inner_e)}")
+
+                    # Last resort: try using torchscript if available
+                    try:
+                        logger.info("Attempting final approach with model.half() and direct call")
+                        # Force model to half precision
+                        self._model = self._model.half()
+
+                        # Try direct call with the original method
+                        result = self._model.chat(
+                            self._tokenizer,
+                            str(file_path),
+                            ocr_type
+                        )
+                    except Exception as final_e:
+                        logger.error(f"All fallback approaches failed: {str(final_e)}")
+                        raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(final_e)}")
             else:
                 # Re-raise other errors
                 raise
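As a side note, a quick way to confirm that the .half() casts and parameter loops above actually leave every floating-point parameter and buffer in float16 is a check along these lines; this is a hypothetical sanity check, not part of the commit, and it only relies on the GotOcrParser._model class attribute shown in the diff:

import torch

def assert_all_float16(model: torch.nn.Module) -> None:
    # Hypothetical check: fail loudly if any floating-point parameter or buffer
    # ended up in a dtype other than float16 (e.g. leftover bfloat16 or float32).
    for name, tensor in list(model.named_parameters()) + list(model.named_buffers()):
        if tensor.is_floating_point() and tensor.dtype != torch.float16:
            raise RuntimeError(f"{name} is {tensor.dtype}, expected torch.float16")

# e.g. assert_all_float16(GotOcrParser._model) once the model has been loaded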