Spaces:
Runtime error
Runtime error
Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
Browse files
src/parsers/got_ocr_parser.py
CHANGED
@@ -107,13 +107,19 @@ class GotOcrParser(DocumentParser):
|
|
107 |
logger.warning("No GPU available, falling back to CPU (not recommended)")
|
108 |
device_map = 'auto'
|
109 |
|
|
|
|
|
|
|
|
|
|
|
110 |
cls._model = AutoModel.from_pretrained(
|
111 |
'stepfun-ai/GOT-OCR2_0',
|
112 |
trust_remote_code=True,
|
113 |
low_cpu_mem_usage=True,
|
114 |
device_map=device_map,
|
115 |
use_safetensors=True,
|
116 |
-
pad_token_id=cls._tokenizer.eos_token_id
|
|
|
117 |
)
|
118 |
|
119 |
# Set model to evaluation mode
|
@@ -121,6 +127,10 @@ class GotOcrParser(DocumentParser):
|
|
121 |
cls._model = cls._model.eval().cuda()
|
122 |
else:
|
123 |
cls._model = cls._model.eval()
|
|
|
|
|
|
|
|
|
124 |
|
125 |
logger.info("GOT-OCR model loaded successfully")
|
126 |
except Exception as e:
|
|
|
107 |
logger.warning("No GPU available, falling back to CPU (not recommended)")
|
108 |
device_map = 'auto'
|
109 |
|
110 |
+
# Set torch default dtype to float16 since the CUDA device doesn't support bfloat16
|
111 |
+
logger.info("Setting default tensor type to float16")
|
112 |
+
torch.set_default_tensor_type(torch.FloatTensor)
|
113 |
+
torch.set_default_dtype(torch.float16)
|
114 |
+
|
115 |
cls._model = AutoModel.from_pretrained(
|
116 |
'stepfun-ai/GOT-OCR2_0',
|
117 |
trust_remote_code=True,
|
118 |
low_cpu_mem_usage=True,
|
119 |
device_map=device_map,
|
120 |
use_safetensors=True,
|
121 |
+
pad_token_id=cls._tokenizer.eos_token_id,
|
122 |
+
torch_dtype=torch.float16 # Explicitly specify float16 dtype
|
123 |
)
|
124 |
|
125 |
# Set model to evaluation mode
|
|
|
127 |
cls._model = cls._model.eval().cuda()
|
128 |
else:
|
129 |
cls._model = cls._model.eval()
|
130 |
+
|
131 |
+
# Reset default dtype to float32 after model loading
|
132 |
+
torch.set_default_dtype(torch.float32)
|
133 |
+
torch.set_default_tensor_type(torch.FloatTensor)
|
134 |
|
135 |
logger.info("GOT-OCR model loaded successfully")
|
136 |
except Exception as e:
|