Spaces:
Runtime error
Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
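The failure appears to be GOT-OCR falling back to bfloat16 on a GPU that cannot run it: the T4 referenced in the commit below is compute capability 7.5, while bfloat16 kernels need 8.0 (Ampere) or newer. As a quick standalone check, outside this Space's code, the device's support can be confirmed with PyTorch's own helpers:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"{torch.cuda.get_device_name(0)}: compute capability {major}.{minor}")
    # False on pre-Ampere GPUs such as the T4
    print("bfloat16 supported:", torch.cuda.is_bf16_supported())
else:
    print("No CUDA device visible")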
src/parsers/got_ocr_parser.py
CHANGED
@@ -7,6 +7,11 @@ import logging
 import sys
 import importlib
 
+# Set PyTorch environment variables to force float16 instead of bfloat16
+os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX"  # For T4 GPU compatibility
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
+
 from src.parsers.parser_interface import DocumentParser
 from src.parsers.parser_registry import ParserRegistry
 
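The environment variables above are one way of nudging everything toward float16. A complementary approach, sketched here with a hypothetical pick_supported_dtype helper that is not part of this repository, is to detect bfloat16 support at runtime and hand the result to the model loader:

import torch

def pick_supported_dtype() -> torch.dtype:
    # bfloat16 needs an Ampere-class GPU (compute capability >= 8.0);
    # fall back to float16 elsewhere, e.g. on a T4.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

Since the loader below already passes torch_dtype=torch.float16 explicitly, such a helper would only matter if bfloat16 were preferred on GPUs that do support it.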
@@ -122,6 +127,32 @@ class GotOcrParser(DocumentParser):
             torch_dtype=torch.float16  # Explicitly specify float16 dtype
         )
 
+        # Patch the model's chat method to use float16 instead of bfloat16
+        logger.info("Patching model to use float16 instead of bfloat16")
+        original_chat = cls._model.chat
+
+        def patched_chat(self, tokenizer, image_path, *args, **kwargs):
+            # Check if patch is working
+            logger.info("Using patched chat method with float16")
+
+            # Set explicit autocast dtype
+            if hasattr(torch.amp, 'autocast'):
+                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                    try:
+                        return original_chat(tokenizer, image_path, *args, **kwargs)
+                    except RuntimeError as e:
+                        if "bfloat16" in str(e):
+                            logger.error(f"BFloat16 error encountered despite patching: {e}")
+                            raise RuntimeError(f"GPU doesn't support bfloat16: {e}")
+                        else:
+                            raise
+            else:
+                return original_chat(tokenizer, image_path, *args, **kwargs)
+
+        # Apply the patch
+        import types
+        cls._model.chat = types.MethodType(patched_chat, cls._model)
+
         # Set model to evaluation mode
         if device_map == 'cuda':
             cls._model = cls._model.eval().cuda()
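The patch above swaps in patched_chat with types.MethodType while keeping a reference to the original bound method. Because original_chat is already bound to the model instance, it is delegated to with just (tokenizer, image_path, ...) rather than with an extra self, which would shift every argument by one. A minimal, self-contained sketch of the same rebinding pattern, independent of GOT-OCR:

import types

class Greeter:
    def greet(self, name):
        return f"hello {name}"

g = Greeter()
original_greet = g.greet  # bound method: 'g' is already baked in

def patched_greet(self, name):
    # Delegate without passing 'self' again; original_greet is already bound.
    return original_greet(name).upper()

g.greet = types.MethodType(patched_greet, g)
print(g.greet("world"))  # -> HELLO WORLD

The commit applies the same pattern to cls._model.chat, with the torch.amp.autocast context wrapped around the delegated call.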
@@ -313,11 +344,38 @@ class GotOcrParser(DocumentParser):
 
         # Use the model's chat method as shown in the documentation
         logger.info(f"Processing image with GOT-OCR: {file_path}")
-        result = self._model.chat(
-            self._tokenizer,
-            str(file_path),
-            ocr_type=ocr_type
-        )
+        try:
+            # First try with patched method
+            result = self._model.chat(
+                self._tokenizer,
+                str(file_path),
+                ocr_type=ocr_type
+            )
+        except RuntimeError as e:
+            if "bfloat16" in str(e) or "BFloat16" in str(e):
+                logger.warning("Caught bfloat16 error, trying to force float16 with autocast")
+                # Try with explicit autocast
+                try:
+                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                        # Temporarily set default dtype
+                        old_dtype = torch.get_default_dtype()
+                        torch.set_default_dtype(torch.float16)
+
+                        # Call the original method directly
+                        result = self._model.chat(
+                            self._tokenizer,
+                            str(file_path),
+                            ocr_type=ocr_type
+                        )
+
+                        # Restore default dtype
+                        torch.set_default_dtype(old_dtype)
+                except Exception as inner_e:
+                    logger.error(f"Error in fallback method: {str(inner_e)}")
+                    raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(inner_e)}")
+            else:
+                # Re-raise other errors
+                raise
 
         # Return the result directly as markdown
         return result
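The fallback path temporarily switches the default dtype and restores it on the success path; if the retried chat call itself raises, the restore is skipped and the process-wide default stays at float16. One way to make the restore unconditional is a try/finally wrapper along these lines (run_with_float16_fallback is a hypothetical helper, not part of this commit):

import torch

def run_with_float16_fallback(call):
    # Force float16 for the duration of 'call', restoring the previous
    # default dtype whether the call succeeds or raises.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            return call()
    finally:
        torch.set_default_dtype(old_dtype)

Usage would look like result = run_with_float16_fallback(lambda: self._model.chat(self._tokenizer, str(file_path), ocr_type=ocr_type)).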