File size: 15,038 Bytes
36e49b4
 
 
 
 
 
 
dda982a
b61bec6
0f5865d
76a030b
0f5865d
 
 
4cac30a
d95f73e
dda982a
 
0f5865d
5b7f920
 
dda982a
 
 
5bb2b30
dda982a
3415bc4
 
 
 
dda982a
4fab3b3
0f5865d
4fab3b3
 
b61bec6
dda982a
3415bc4
 
dda982a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4523ddf
 
 
 
dda982a
b61bec6
 
 
 
 
17cb3f3
3415bc4
 
b61bec6
 
 
 
 
dda982a
b61bec6
76a030b
b61bec6
 
3415bc4
b61bec6
 
 
ad248f7
b61bec6
3415bc4
b61bec6
 
 
4fab3b3
b61bec6
8efc942
b61bec6
dda982a
b61bec6
 
 
dda982a
 
b61bec6
dda982a
 
 
3415bc4
4fab3b3
dda982a
4fab3b3
 
0f5865d
4cac30a
 
 
 
 
 
 
 
 
 
 
 
4fab3b3
dda982a
4fab3b3
 
4cac30a
 
 
 
 
4fab3b3
3415bc4
 
4fab3b3
 
4cac30a
4fab3b3
0f5865d
8efc942
b61bec6
8efc942
b61bec6
 
 
8efc942
 
 
 
4fab3b3
 
 
 
 
3415bc4
 
 
 
 
4cac30a
 
 
 
 
b61bec6
 
 
610b772
4cac30a
 
 
 
 
 
 
 
 
 
 
 
 
3415bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33f1b65
 
 
 
3415bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cac30a
3415bc4
 
 
4cac30a
3415bc4
 
 
 
4fab3b3
3415bc4
 
4fab3b3
 
 
 
3415bc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fab3b3
 
 
 
3415bc4
4fab3b3
 
 
 
 
3415bc4
4fab3b3
 
3415bc4
 
4fab3b3
 
 
 
3415bc4
4fab3b3
 
 
 
33f1b65
 
 
4fab3b3
 
3415bc4
4fab3b3
 
 
 
 
3415bc4
4fab3b3
 
3415bc4
 
4fab3b3
 
 
 
3415bc4
4fab3b3
 
 
 
 
3415bc4
 
4fab3b3
3415bc4
 
 
 
 
 
 
 
4fab3b3
3415bc4
610b772
 
4cac30a
610b772
3415bc4
4fab3b3
4cac30a
4fab3b3
0f5865d
 
 
3415bc4
 
dda982a
b61bec6
8efc942
3415bc4
b61bec6
3415bc4
4cac30a
b61bec6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
# Import spaces module for ZeroGPU support - Must be first import
# (In Hugging Face ZeroGPU Spaces, `spaces` must be imported before any
# torch-related module so GPU allocation can be patched correctly.)
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    # Not running in a Space (or `spaces` not installed); the parser falls
    # back to direct CPU model loading in _process_image_without_gpu.
    HAS_SPACES = False

from pathlib import Path
import os
import logging
import sys
import tempfile
import shutil
from typing import Dict, List, Optional, Any, Union
import copy

from src.parsers.parser_interface import DocumentParser
from src.parsers.parser_registry import ParserRegistry

# Import latex2markdown for conversion - No longer needed, using Gemini API
# import latex2markdown

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Constants
MODEL_NAME = "stepfun-ai/GOT-OCR-2.0-hf"  # Hugging Face model id for GOT-OCR 2.0
STOP_STR = "<|im_end|>"  # generation stop string used by the GOT-OCR model

class GotOcrParser(DocumentParser):
    """Parser implementation using GOT-OCR 2.0 for document text extraction using transformers.

    This implementation uses the transformers model directly for better integration with
    ZeroGPU and avoids subprocess complexity.

    CUDA discipline: in ZeroGPU environments CUDA must never be initialized in the
    main process, so all torch/transformers imports and model loading happen inside
    the processing methods, and GPU work is confined to the @spaces.GPU-decorated
    method defined below.
    """

    # Class variable holding model state information only (never the model object
    # itself) - the model is loaded per call inside the processing methods.
    _model_loaded = False

    @classmethod
    def get_name(cls) -> str:
        """Return the human-readable parser name shown to users."""
        return "GOT-OCR (jpg,png only)"

    @classmethod
    def get_supported_ocr_methods(cls) -> List[Dict[str, Any]]:
        """Return the OCR modes this parser supports ('plain' and 'format')."""
        return [
            {
                "id": "plain",
                "name": "Plain Text",
                "default_params": {}
            },
            {
                "id": "format",
                "name": "Formatted Text",
                "default_params": {}
            }
        ]

    @classmethod
    def get_description(cls) -> str:
        """Return a one-line description of this parser."""
        return "GOT-OCR 2.0 parser for converting images to text (requires CUDA)"

    @classmethod
    def _check_dependencies(cls) -> bool:
        """Check if all required dependencies are installed.

        Returns:
            True if torch and transformers are importable, False otherwise.
        """
        try:
            import torch
            import transformers

            # Only check if the modules are importable, DO NOT use torch.cuda here
            # as it would initialize CUDA in the main process
            return True
        except ImportError as e:
            logger.error(f"Missing dependency: {e}")
            return False

    def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
        """Parse a document using GOT-OCR 2.0.

        Args:
            file_path: Path to the image file (.jpg/.jpeg/.png only)
            ocr_method: OCR method to use ('plain', 'format')
            **kwargs: Additional arguments to pass to the model

        Returns:
            Extracted text from the image, converted to Markdown if formatted

        Raises:
            ImportError: If torch/transformers are missing.
            FileNotFoundError: If the image file does not exist.
            ValueError: If the file extension is not JPG/PNG.
            RuntimeError: For processing failures (OOM, CUDA init, pickling, ...).
        """
        # Import locally so this method never depends on the module-bottom
        # registration block having bound `pickle` (it may not run on import
        # failure, which previously risked a NameError below).
        import pickle

        # Verify dependencies are installed without initializing CUDA
        if not self._check_dependencies():
            raise ImportError(
                "Required dependencies are missing. Please install: "
                "torch transformers"
            )

        # Validate file path and extension
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Image file not found: {file_path}")

        if file_path.suffix.lower() not in ['.jpg', '.jpeg', '.png']:
            raise ValueError(
                f"GOT-OCR only supports JPG and PNG formats. "
                f"Received file with extension: {file_path.suffix}"
            )

        # Determine OCR mode based on method
        use_format = ocr_method == "format"

        # Log the OCR method being used
        logger.info(f"Using OCR method: {ocr_method or 'plain'}")

        # Filter kwargs to remove any objects that can't be pickled (like thread
        # locks) - ZeroGPU serializes arguments across process boundaries.
        safe_kwargs = {}
        for key, value in kwargs.items():
            # Skip private keys and bare classes
            if not key.startswith('_') and not isinstance(value, type):
                try:
                    # Test if it can be copied - this helps identify unpicklable objects
                    copy.deepcopy(value)
                    safe_kwargs[key] = value
                except (TypeError, pickle.PickleError):
                    logger.warning(f"Skipping unpicklable kwarg: {key}")

        # Process the image using transformers
        try:
            # Use the spaces.GPU decorator if available
            if HAS_SPACES:
                # Use string path instead of Path object for better pickling
                image_path_str = str(file_path)

                # Call the wrapper function that handles ZeroGPU safely
                return self._safe_gpu_process(image_path_str, use_format, **safe_kwargs)
            else:
                # Fallback for environments without spaces
                return self._process_image_without_gpu(
                    str(file_path),
                    use_format=use_format,
                    **safe_kwargs
                )

        except Exception as e:
            logger.error(f"Error processing image with GOT-OCR: {str(e)}")

            # Handle specific errors with helpful messages
            error_type = type(e).__name__
            if error_type == 'OutOfMemoryError':
                raise RuntimeError(
                    "GPU out of memory while processing with GOT-OCR. "
                    "Try using a smaller image or a different parser."
                )
            elif "bfloat16" in str(e):
                raise RuntimeError(
                    "CUDA device does not support bfloat16. This is a known issue with some GPUs. "
                    "Please try using a different parser or contact support."
                )
            elif "CUDA must not be initialized" in str(e):
                raise RuntimeError(
                    "CUDA initialization error. This is likely due to model loading in the main process. "
                    "In ZeroGPU environments, CUDA must only be initialized within @spaces.GPU decorated functions."
                )
            elif "cannot pickle" in str(e):
                raise RuntimeError(
                    f"Serialization error with ZeroGPU: {str(e)}. "
                    "This may be due to thread locks or other unpicklable objects being passed."
                )

            # Generic error
            raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")

    def _safe_gpu_process(self, image_path: str, use_format: bool, **kwargs):
        """Safe wrapper for GPU processing to avoid pickle issues with thread locks.

        Falls back to CPU processing if ZeroGPU argument serialization fails.
        """
        import pickle

        try:
            # Call the GPU-decorated function with minimal, picklable arguments
            return self._process_image_with_gpu(image_path, use_format)
        except pickle.PickleError as e:
            logger.error(f"Pickle error in ZeroGPU processing: {str(e)}")
            # Fall back to CPU processing if pickling fails
            logger.warning("Falling back to CPU processing due to pickling error")
            return self._process_image_without_gpu(image_path, use_format=use_format)

    @staticmethod
    def _run_ocr(model, processor, image, use_format: bool, device: str = "cpu") -> str:
        """Run a single generate/decode pass for one image.

        Shared by the CPU and GPU paths so the inference logic exists once.

        Args:
            model: Loaded AutoModelForImageTextToText instance (in eval mode).
            processor: Matching AutoProcessor.
            image: Image object as returned by transformers' load_image.
            use_format: True for formatted (LaTeX-style) output, False for plain text.
            device: Device the model lives on; inputs are moved there when not CPU.

        Returns:
            Decoded text with any trailing stop string removed (not whitespace-stripped).
        """
        import torch

        # Build inputs; `format=True` selects the formatted-output prompt
        if use_format:
            inputs = processor([image], return_tensors="pt", format=True)
        else:
            inputs = processor([image], return_tensors="pt")

        if device != "cpu":
            inputs = inputs.to(device)

        # Greedy generation, stopping at the model's end-of-turn marker
        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                do_sample=False,
                tokenizer=processor.tokenizer,
                stop_strings=STOP_STR,
                max_new_tokens=4096,
            )

        # Decode only the newly generated tokens (skip the prompt)
        result = processor.decode(
            generate_ids[0, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # Belt-and-braces: drop an explicit trailing stop string if the
        # tokenizer did not already strip it as a special token.
        if result.endswith(STOP_STR):
            result = result[:-len(STOP_STR)]
        return result

    def _process_image_without_gpu(self, image_path: str, use_format: bool = False, **kwargs) -> str:
        """Process an image with GOT-OCR model when not using ZeroGPU.

        Runs entirely on CPU to avoid initializing CUDA in the main process.
        """
        logger.warning("ZeroGPU not available. Using direct model loading, which may not work in Spaces.")

        # Import here to avoid CUDA initialization in main process
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor
        from transformers.image_utils import load_image

        # Load the image
        image = load_image(image_path)

        # Load processor and model; CPU only to avoid CUDA initialization issues
        processor = AutoProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_NAME,
            low_cpu_mem_usage=True,
            device_map="cpu"
        ).eval()

        try:
            result = self._run_ocr(model, processor, image, use_format, device="cpu")
            if use_format:
                # Return raw LaTeX output - let post-processing handle conversion
                # This allows for more advanced conversion in the integration module
                logger.info("Returning raw LaTeX output for external processing")
        finally:
            # Clean up to free memory even if generation failed
            del model
            del processor
            import gc
            gc.collect()

        return result.strip()

    # Define the GPU-decorated function for ZeroGPU
    if HAS_SPACES:
        @spaces.GPU()  # Use default ZeroGPU allocation timeframe, matching HF implementation
        def _process_image_with_gpu(self, image_path: str, use_format: bool = False) -> str:
            """Process an image with GOT-OCR model using GPU allocation.

            IMPORTANT: All model loading and CUDA operations must happen inside this method.
            NOTE: Function must receive only picklable arguments (no thread locks, etc).
            """
            logger.info("Processing with ZeroGPU allocation")

            # Imports inside the GPU-decorated function
            import torch
            from transformers import AutoModelForImageTextToText, AutoProcessor
            from transformers.image_utils import load_image

            # Load the image
            image = load_image(image_path)

            # CUDA may be probed here: we are inside the @spaces.GPU scope
            device = "cuda" if torch.cuda.is_available() else "cpu"

            logger.info(f"Loading GOT-OCR model from {MODEL_NAME} on {device}")

            # Load processor and model, then switch to evaluation mode
            processor = AutoProcessor.from_pretrained(MODEL_NAME)
            model = AutoModelForImageTextToText.from_pretrained(
                MODEL_NAME,
                low_cpu_mem_usage=True,
                device_map=device
            ).eval()

            try:
                result = self._run_ocr(model, processor, image, use_format, device=device)
                if use_format:
                    # Return raw LaTeX output - let post-processing handle conversion
                    # This allows for more advanced conversion in the integration module
                    logger.info("Returning raw LaTeX output for external processing")
            finally:
                # Clean up to free memory even if generation failed
                del model
                del processor
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("CUDA cache cleared")

            return result.strip()
    else:
        # Define a dummy method if spaces is not available
        def _process_image_with_gpu(self, image_path: str, use_format: bool = False) -> str:
            """Fallback stub; delegates to the CPU path (should never run when HAS_SPACES is False)."""
            return self._process_image_without_gpu(
                image_path,
                use_format=use_format
            )

    @classmethod
    def release_model(cls):
        """Release model resources - not needed with new implementation."""
        logger.info("Model resources managed by ZeroGPU decorator")

# Try to register the parser at import time. Registration is deliberately
# inside the same try as the imports so that ANY ImportError in the chain
# (including one raised during register()) downgrades to a warning instead
# of crashing module import.
try:
    # Only check basic imports, no CUDA initialization
    import torch
    import transformers
    import pickle  # Import pickle for serialization error handling
    ParserRegistry.register(GotOcrParser)
    logger.info("GOT-OCR parser registered successfully")
except ImportError as e:
    logger.warning(f"Could not register GOT-OCR parser: {str(e)}")