Spaces:
Runtime error
Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
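The message itself points at the fix: bfloat16 needs an Ampere-or-newer GPU (compute capability 8.0 or higher), while the T4 targeted by the TORCH_CUDA_ARCH_LIST setting below reports 7.5, so the dtype has to be chosen at load time. A minimal sketch of that check, using standard PyTorch APIs rather than code from this repository:

import torch

# bfloat16 requires compute capability >= 8.0; a T4 reports (7, 5).
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float16

The commit below hard-codes the float16 branch of that decision and then patches every place where the GOT-OCR remote code might still ask for bfloat16.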
src/parsers/got_ocr_parser.py (CHANGED: +141 -9)
@@ -11,6 +11,17 @@ import importlib
os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX"  # For T4 GPU compatibility
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
+os.environ["PYTORCH_DISPATCHER_DISABLE_TORCH_FUNCTION_AUTOGRAD_FALLBACK"] = "1"  # Disable fallbacks that might use bfloat16
+
+# Add patch for bfloat16 at the module level
+if 'torch' in sys.modules:
+    torch_module = sys.modules['torch']
+    if hasattr(torch_module, 'bfloat16'):
+        # Create a reference to the original bfloat16 function
+        original_bfloat16 = torch_module.bfloat16
+        # Replace it with float16
+        torch_module.bfloat16 = torch_module.float16
+        logger.info("Patched torch.bfloat16 to use torch.float16 instead")

from src.parsers.parser_interface import DocumentParser
from src.parsers.parser_registry import ParserRegistry
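The module-level alias above only affects code that looks the dtype up through the torch.bfloat16 attribute at call time; anything that captured a reference to the real dtype object before this file was imported, or tensors that already hold bfloat16 data, is untouched. A two-line check of what the alias does, as a sketch independent of this repository:

import torch

torch.bfloat16 = torch.float16                    # same alias as above
print(torch.ones(2).to(torch.bfloat16).dtype)     # torch.float16 - the name now resolves to float16

That narrow reach is presumably why the commit also patches autocast, Tensor.to and the model object itself in the hunks that follow.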
@@ -117,6 +128,39 @@ class GotOcrParser(DocumentParser):
            torch.set_default_tensor_type(torch.FloatTensor)
            torch.set_default_dtype(torch.float16)

+            # Aggressively patch torch.autocast to always use float16
+            original_autocast = torch.amp.autocast if hasattr(torch.amp, 'autocast') else None
+
+            if original_autocast:
+                def patched_autocast(*args, **kwargs):
+                    # Force dtype to float16
+                    kwargs['dtype'] = torch.float16
+                    return original_autocast(*args, **kwargs)
+
+                torch.amp.autocast = patched_autocast
+                logger.info("Patched torch.amp.autocast to always use float16")
+
+            # Patch tensor casting methods for bfloat16
+            if hasattr(torch, 'Tensor'):
+                if hasattr(torch.Tensor, 'to'):
+                    original_to = torch.Tensor.to
+                    def patched_to(self, *args, **kwargs):
+                        # If the first arg is a dtype and it's bfloat16, replace with float16
+                        if args and args[0] == torch.bfloat16:
+                            logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
+                            args = list(args)
+                            args[0] = torch.float16
+                            args = tuple(args)
+                        # If dtype is specified in kwargs and it's bfloat16, replace with float16
+                        if kwargs.get('dtype') == torch.bfloat16:
+                            logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
+                            kwargs['dtype'] = torch.float16
+                        return original_to(self, *args, **kwargs)
+
+                    torch.Tensor.to = patched_to
+                    logger.info("Patched torch.Tensor.to method to prevent bfloat16 usage")
+
+            # Load the model with explicit float16 dtype
            cls._model = AutoModel.from_pretrained(
                'stepfun-ai/GOT-OCR2_0',
                trust_remote_code=True,
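One caveat on the patched_autocast wrapper above: torch.amp.autocast also accepts the dtype as a positional argument, and that form would bypass the keyword override (or collide with the injected keyword once dtype is forced into kwargs). A more defensive variant that normalizes both forms, shown only as a sketch and not part of this commit:

def patched_autocast(*args, **kwargs):
    # args is (device_type, dtype, ...) when dtype is passed positionally
    args = list(args)
    if len(args) >= 2 and args[1] == torch.bfloat16:
        args[1] = torch.float16
    if kwargs.get('dtype') == torch.bfloat16:
        kwargs['dtype'] = torch.float16
    return original_autocast(*args, **kwargs)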
@@ -127,6 +171,35 @@ class GotOcrParser(DocumentParser):
                torch_dtype=torch.float16  # Explicitly specify float16 dtype
            )

+            # Ensure all model parameters are float16
+            for param in cls._model.parameters():
+                param.data = param.data.to(torch.float16)
+
+            # Examine model internals to find any direct bfloat16 usage
+            def find_and_patch_bfloat16_attributes(module, path=""):
+                for name, child in module.named_children():
+                    child_path = f"{path}.{name}" if path else name
+                    # Check if any attribute contains "bfloat16" in its name
+                    for attr_name in dir(child):
+                        if "bfloat16" in attr_name.lower():
+                            try:
+                                # Try to get the attribute
+                                attr_value = getattr(child, attr_name)
+                                logger.warning(f"Found potential bfloat16 usage at {child_path}.{attr_name}")
+                                # Try to replace with float16 equivalent if it exists
+                                float16_attr_name = attr_name.replace("bfloat16", "float16").replace("bf16", "fp16")
+                                if hasattr(child, float16_attr_name):
+                                    logger.info(f"Replacing {attr_name} with {float16_attr_name}")
+                                    setattr(child, attr_name, getattr(child, float16_attr_name))
+                            except Exception as e:
+                                logger.error(f"Error examining attribute {attr_name}: {e}")
+                    # Recursively check child modules
+                    find_and_patch_bfloat16_attributes(child, child_path)
+
+            # Apply the internal examination
+            logger.info("Examining model for potential bfloat16 usage...")
+            find_and_patch_bfloat16_attributes(cls._model)
+
            # Patch the model's chat method to use float16 instead of bfloat16
            logger.info("Patching model to use float16 instead of bfloat16")
            original_chat = cls._model.chat
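Note that the parameter loop above only walks Module.parameters(); registered buffers are not included in that iterator. The blanket cast the commit applies later via .half() (or, equivalently, the line below) converts floating-point buffers as well:

# Equivalent blanket cast: converts floating-point parameters and buffers in one call.
cls._model = cls._model.to(torch.float16)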
@@ -144,6 +217,11 @@ class GotOcrParser(DocumentParser):
                """A patched version of chat method that forces float16 precision"""
                logger.info(f"Using patched chat method with float16, ocr_type={ocr_type}")

+                # Force any bfloat16 tensors to float16
+                if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
+                    torch.bfloat16 = torch.float16
+                    logger.info("Forcing torch.bfloat16 to be torch.float16 within chat method")
+
                # Set explicit autocast dtype
                if hasattr(torch.amp, 'autocast'):
                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
@@ -162,11 +240,16 @@ class GotOcrParser(DocumentParser):
                        except RuntimeError as e:
                            if "bfloat16" in str(e):
                                logger.error(f"BFloat16 error encountered despite patching: {e}")
+                                # More aggressive handling
+                                if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
+                                    logger.info("Attempting with torch.cuda.amp.autocast as last resort")
+                                    with torch.cuda.amp.autocast(dtype=torch.float16):
+                                        return original_chat(tokenizer, str(image_path), ocr_type, **kwargs)
                                raise RuntimeError(f"GPU doesn't support bfloat16: {e}")
                            else:
                                raise
                else:
-                    #
+                    # If autocast is not available, try to manually ensure everything is float16
                    try:
                        # Direct call without 'self' as first arg
                        return original_chat(tokenizer, image_path, ocr_type, **kwargs)
@@ -183,11 +266,27 @@ class GotOcrParser(DocumentParser):
            import types
            cls._model.chat = types.MethodType(patched_chat, cls._model)

+            # Check if the model has a cast_to_bfloat16 method and override it
+            if hasattr(cls._model, 'cast_to_bfloat16'):
+                original_cast = cls._model.cast_to_bfloat16
+                def patched_cast(self, *args, **kwargs):
+                    logger.info("Intercepted attempt to cast model to bfloat16, using float16 instead")
+                    # If the model has a cast_to_float16 method, use that instead
+                    if hasattr(self, 'cast_to_float16'):
+                        return self.cast_to_float16(*args, **kwargs)
+                    # Otherwise, cast all parameters manually
+                    for param in self.parameters():
+                        param.data = param.data.to(torch.float16)
+                    return self
+
+                cls._model.cast_to_bfloat16 = types.MethodType(patched_cast, cls._model)
+                logger.info("Patched model.cast_to_bfloat16 method")
+
            # Set model to evaluation mode
            if device_map == 'cuda':
-                cls._model = cls._model.eval().cuda()
+                cls._model = cls._model.eval().cuda().half()  # Explicitly cast to half precision (float16)
            else:
-                cls._model = cls._model.eval()
+                cls._model = cls._model.eval().half()  # Explicitly cast to half precision (float16)

            # Reset default dtype to float32 after model loading
            torch.set_default_dtype(torch.float32)
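Once loading has finished, the effect of the cast can be sanity-checked; a short sketch, assuming cls._model is the handle the parser keeps:

float_dtypes = {t.dtype
                for t in list(cls._model.parameters()) + list(cls._model.buffers())
                if t.is_floating_point()}
print(float_dtypes)   # expected: {torch.float16} after eval().half()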
@@ -377,22 +476,40 @@ class GotOcrParser(DocumentParser):
            try:
                # Use ocr_type as a positional argument based on the correct signature
                logger.info(f"Using OCR method: {ocr_type}")
-
-
-
-
-
+
+                # Temporarily force any PyTorch operations to use float16
+                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                    result = self._model.chat(
+                        self._tokenizer,
+                        str(file_path),
+                        ocr_type  # Pass as positional arg, not keyword
+                    )
            except RuntimeError as e:
                if "bfloat16" in str(e) or "BFloat16" in str(e):
                    logger.warning("Caught bfloat16 error, trying to force float16 with autocast")
                    # Try with explicit autocast
                    try:
+                        # More aggressive approach with multiple settings
+
+                        # Ensure bfloat16 is aliased to float16 globally
+                        if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
+                            logger.info("Forcing bfloat16 to be float16 in exception handler")
+                            torch.bfloat16 = torch.float16
+
+                        # Apply patch to the model's config if it exists
+                        if hasattr(self._model, 'config'):
+                            if hasattr(self._model.config, 'torch_dtype'):
+                                logger.info(f"Setting model config dtype from {self._model.config.torch_dtype} to float16")
+                                self._model.config.torch_dtype = torch.float16
+
+                        # Try with all possible autocast combinations
                        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                            # Temporarily set default dtype
                            old_dtype = torch.get_default_dtype()
                            torch.set_default_dtype(torch.float16)

                            # Call with positional argument for ocr_type
+                            logger.info("Using fallback with autocast and default dtype set to float16")
                            result = self._model.chat(
                                self._tokenizer,
                                str(file_path),
@@ -403,7 +520,22 @@ class GotOcrParser(DocumentParser):
                            torch.set_default_dtype(old_dtype)
                    except Exception as inner_e:
                        logger.error(f"Error in fallback method: {str(inner_e)}")
-
+
+                        # Last resort: try using torchscript if available
+                        try:
+                            logger.info("Attempting final approach with model.half() and direct call")
+                            # Force model to half precision
+                            self._model = self._model.half()
+
+                            # Try direct call with the original method
+                            result = self._model.chat(
+                                self._tokenizer,
+                                str(file_path),
+                                ocr_type
+                            )
+                        except Exception as final_e:
+                            logger.error(f"All fallback approaches failed: {str(final_e)}")
+                            raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(final_e)}")
                else:
                    # Re-raise other errors
                    raise
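Taken together, the intended happy path after this commit boils down to the following shape. This is a condensed sketch rather than the parser's exact code; sample_page.png and the 'ocr' mode are placeholder inputs:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained(
    'stepfun-ai/GOT-OCR2_0',
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model = model.eval().cuda().half()

file_path = "sample_page.png"   # placeholder input image
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    result = model.chat(tokenizer, str(file_path), 'ocr')   # ocr_type passed positionally, as in the patch

If the chat call still raises a bfloat16 RuntimeError on this path, the fallbacks above retry with the default dtype forced to float16 and finally with a bare model.half() call.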