AnseMin committed on
Commit
84a7af0
·
1 Parent(s): 6372663

Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.

Browse files
Files changed (1) hide show
  1. src/parsers/got_ocr_parser.py +11 -1
src/parsers/got_ocr_parser.py CHANGED
@@ -107,13 +107,19 @@ class GotOcrParser(DocumentParser):
107
  logger.warning("No GPU available, falling back to CPU (not recommended)")
108
  device_map = 'auto'
109
 
 
 
 
 
 
110
  cls._model = AutoModel.from_pretrained(
111
  'stepfun-ai/GOT-OCR2_0',
112
  trust_remote_code=True,
113
  low_cpu_mem_usage=True,
114
  device_map=device_map,
115
  use_safetensors=True,
116
- pad_token_id=cls._tokenizer.eos_token_id
 
117
  )
118
 
119
  # Set model to evaluation mode
@@ -121,6 +127,10 @@ class GotOcrParser(DocumentParser):
121
  cls._model = cls._model.eval().cuda()
122
  else:
123
  cls._model = cls._model.eval()
 
 
 
 
124
 
125
  logger.info("GOT-OCR model loaded successfully")
126
  except Exception as e:
 
107
  logger.warning("No GPU available, falling back to CPU (not recommended)")
108
  device_map = 'auto'
109
 
110
+ # Set torch default dtype to float16 since the CUDA device doesn't support bfloat16
111
+ logger.info("Setting default tensor type to float16")
112
+ torch.set_default_tensor_type(torch.FloatTensor)
113
+ torch.set_default_dtype(torch.float16)
114
+
115
  cls._model = AutoModel.from_pretrained(
116
  'stepfun-ai/GOT-OCR2_0',
117
  trust_remote_code=True,
118
  low_cpu_mem_usage=True,
119
  device_map=device_map,
120
  use_safetensors=True,
121
+ pad_token_id=cls._tokenizer.eos_token_id,
122
+ torch_dtype=torch.float16 # Explicitly specify float16 dtype
123
  )
124
 
125
  # Set model to evaluation mode
 
127
  cls._model = cls._model.eval().cuda()
128
  else:
129
  cls._model = cls._model.eval()
130
+
131
+ # Reset default dtype to float32 after model loading
132
+ torch.set_default_dtype(torch.float32)
133
+ torch.set_default_tensor_type(torch.FloatTensor)
134
 
135
  logger.info("GOT-OCR model loaded successfully")
136
  except Exception as e: