AnseMin committed on
Commit 0f5865d · 1 Parent(s): ad248f7

Change in strategy -- implementing the GitHub GOT-OCR repository instead of the Hugging Face model

Files changed (4)
  1. app.py +23 -0
  2. requirements.txt +4 -1
  3. setup.sh +10 -3
  4. src/parsers/got_ocr_parser.py +142 -168
app.py CHANGED
@@ -18,6 +18,29 @@ try:
 except Exception as e:
     print(f"Error running setup.sh: {e}")
 
+# Check if git is installed (needed for GOT-OCR)
+try:
+    git_version = subprocess.run(["git", "--version"], capture_output=True, text=True, check=False)
+    if git_version.returncode == 0:
+        print(f"Git found: {git_version.stdout.strip()}")
+    else:
+        print("WARNING: Git not found. GOT-OCR parser requires git for repository cloning.")
+except Exception:
+    print("WARNING: Git not found or not in PATH. GOT-OCR parser requires git for repository cloning.")
+
+# Check if Hugging Face CLI is installed (needed for GOT-OCR)
+try:
+    hf_cli = subprocess.run(["huggingface-cli", "--version"], capture_output=True, text=True, check=False)
+    if hf_cli.returncode == 0:
+        print(f"Hugging Face CLI found: {hf_cli.stdout.strip()}")
+    else:
+        print("WARNING: Hugging Face CLI not found. GOT-OCR parser requires huggingface-cli for model downloads.")
+        print("Installing Hugging Face CLI...")
+        subprocess.run([sys.executable, "-m", "pip", "install", "-q", "huggingface_hub[cli]"], check=False)
+except Exception:
+    print("WARNING: Hugging Face CLI not found. Installing...")
+    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "huggingface_hub[cli]"], check=False)
+
 # Try to load environment variables from .env file
 try:
     from dotenv import load_dotenv
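The startup checks above follow a single pattern: probe an external tool with subprocess.run(..., check=False) and treat a zero exit code as "installed". A minimal stand-alone sketch of that pattern (not part of the commit; the tool_available helper name is hypothetical):

    import subprocess

    def tool_available(cmd: list) -> bool:
        # Mirror the checks in app.py: run the probe command without
        # raising on a non-zero exit, and treat returncode 0 as present.
        try:
            probe = subprocess.run(cmd, capture_output=True, text=True, check=False)
            return probe.returncode == 0
        except FileNotFoundError:
            # The executable is not on PATH at all.
            return False

    if __name__ == "__main__":
        print("git:", tool_available(["git", "--version"]))
        print("huggingface-cli:", tool_available(["huggingface-cli", "--version"]))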
requirements.txt CHANGED
@@ -10,6 +10,8 @@ opencv-python-headless>=4.5.0 # Headless version for server environments
 # Utility dependencies
 python-dotenv>=1.0.0
 pydantic==2.7.1
+gitpython>=3.1.0 # For cloning repositories
+latex2markdown>=0.1.0 # For LaTeX to Markdown conversion
 
 # Gemini API client
 google-genai>=0.1.0
@@ -21,4 +23,5 @@ transformers==4.37.2 # Pin to a specific version that works with safetensors 0.
 tiktoken==0.6.0
 verovio==4.3.1
 accelerate==0.28.0
-safetensors==0.4.3 # Updated to meet minimum version required by accelerate
+safetensors==0.4.3 # Updated to meet minimum version required by accelerate
+huggingface_hub[cli]>=0.19.0 # For downloading models from Hugging Face
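The new latex2markdown dependency is consumed by the parser change below roughly as follows (a minimal sketch; the sample LaTeX string is illustrative, the LaTeX2Markdown class and to_markdown method are the same ones the parser calls):

    import latex2markdown

    # Illustrative LaTeX input standing in for GOT-OCR "format" output.
    latex_source = r"\section{Results} The accuracy was \textbf{97\%}."

    # Convert LaTeX to Markdown, as the parser does for ocr_type == "format".
    l2m = latex2markdown.LaTeX2Markdown(latex_source)
    print(l2m.to_markdown())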
setup.sh CHANGED
@@ -11,11 +11,12 @@ if [ "$EUID" -eq 0 ]; then
     echo "Installing system dependencies..."
     apt-get update && apt-get install -y \
         wget \
-        pkg-config
+        pkg-config \
+        git
     echo "System dependencies installed successfully"
 else
     echo "Not running as root. Skipping system dependencies installation."
-    echo "If system dependencies are needed, please run this script with sudo."
+    echo "Make sure git is installed on your system for GOT-OCR to work properly."
 fi
 
 # Install NumPy first as it's required by many other packages
@@ -27,13 +28,19 @@ echo "NumPy installed successfully"
 echo "Installing Python dependencies..."
 pip install -q -U pillow opencv-python-headless
 pip install -q -U google-genai
+pip install -q -U latex2markdown
 echo "Python dependencies installed successfully"
 
 # Install GOT-OCR dependencies
 echo "Installing GOT-OCR dependencies..."
-pip install -q -U torch==2.0.1 torchvision==0.15.2 transformers==4.37.2 tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0 safetensors==0.4.3
+pip install -q -U torch==2.0.1 torchvision==0.15.2 transformers==4.37.2 tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0 safetensors==0.4.3 huggingface_hub
 echo "GOT-OCR dependencies installed successfully"
 
+# Install Hugging Face CLI
+echo "Installing Hugging Face CLI..."
+pip install -q -U "huggingface_hub[cli]"
+echo "Hugging Face CLI installed successfully"
+
 # Install the project in development mode only if setup.py or pyproject.toml exists
 if [ -f "setup.py" ] || [ -f "pyproject.toml" ]; then
     echo "Installing project in development mode..."
src/parsers/got_ocr_parser.py CHANGED
@@ -1,28 +1,31 @@
 from pathlib import Path
-from typing import Dict, List, Optional, Any, Union
-import logging
 import os
+import logging
 import sys
-
-# Set PyTorch environment variables for T4 compatibility
-os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
-os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
+import subprocess
+import tempfile
+import shutil
+from typing import Dict, List, Optional, Any, Union
 
 from src.parsers.parser_interface import DocumentParser
 from src.parsers.parser_registry import ParserRegistry
-from src.utils.latex_converter import latex_to_markdown
+
+# Import latex2markdown instead of custom converter
+import latex2markdown
 
 # Configure logging
 logger = logging.getLogger(__name__)
 
 class GotOcrParser(DocumentParser):
-    """Parser implementation using GOT-OCR 2.0 for document text extraction.
-    Optimized for NVIDIA T4 GPUs with explicit float16 support.
+    """Parser implementation using GOT-OCR 2.0 for document text extraction using GitHub repository.
+
+    This implementation uses the official GOT-OCR2.0 GitHub repository through subprocess calls
+    rather than loading the model directly through Hugging Face Transformers.
     """
 
-    _model = None
-    _tokenizer = None
+    # Path to the GOT-OCR repository
+    _repo_path = None
+    _weights_path = None
 
     @classmethod
     def get_name(cls) -> str:
@@ -51,7 +54,6 @@ class GotOcrParser(DocumentParser):
     def _check_dependencies(cls) -> bool:
         """Check if all required dependencies are installed."""
         try:
-            import numpy
             import torch
             import transformers
             import tiktoken
@@ -60,96 +62,76 @@
             if hasattr(torch, 'cuda') and not torch.cuda.is_available():
                 logger.warning("CUDA is not available. GOT-OCR performs best with GPU acceleration.")
 
+            # Check for latex2markdown
+            try:
+                import latex2markdown
+                logger.info("latex2markdown package found")
+            except ImportError:
+                logger.warning("latex2markdown package not found. Installing...")
+                subprocess.run(
+                    [sys.executable, "-m", "pip", "install", "latex2markdown"],
+                    check=True
+                )
+
             return True
         except ImportError as e:
             logger.error(f"Missing dependency: {e}")
             return False
 
     @classmethod
-    def _load_model(cls):
-        """Load the GOT-OCR model and tokenizer if not already loaded."""
-        if cls._model is None or cls._tokenizer is None:
-            try:
-                # Import dependencies inside the method to avoid global import errors
-                import torch
-                from transformers import AutoModel, AutoTokenizer
-
-                logger.info("Loading GOT-OCR model and tokenizer...")
-
-                # Load tokenizer
-                cls._tokenizer = AutoTokenizer.from_pretrained(
-                    'stepfun-ai/GOT-OCR2_0',
-                    trust_remote_code=True
-                )
-
-                # Determine device
-                device_map = 'cuda' if torch.cuda.is_available() else 'auto'
-                if device_map == 'cuda':
-                    logger.info("Using CUDA for model inference")
-                else:
-                    logger.warning("Using CPU for model inference (not recommended)")
-
-                # Load model with explicit float16 for T4 compatibility
-                cls._model = AutoModel.from_pretrained(
-                    'stepfun-ai/GOT-OCR2_0',
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True,
-                    device_map=device_map,
-                    use_safetensors=True,
-                    torch_dtype=torch.float16, # Force float16 for T4 compatibility
-                    pad_token_id=cls._tokenizer.eos_token_id
-                )
-
-                # Explicitly convert model to half precision (float16)
-                cls._model = cls._model.half().eval()
-
-                # Move to CUDA if available
-                if device_map == 'cuda':
-                    cls._model = cls._model.cuda()
-
-                # Patch torch.autocast to force float16 instead of bfloat16
-                # This fixes the issue in the model's chat method (line 581)
-                original_autocast = torch.autocast
-                def patched_autocast(*args, **kwargs):
-                    # Force dtype to float16 when CUDA is involved
-                    if args and args[0] == "cuda":
-                        kwargs['dtype'] = torch.float16
-                    return original_autocast(*args, **kwargs)
-
-                # Apply the patch
-                torch.autocast = patched_autocast
-                logger.info("Patched torch.autocast to always use float16 for CUDA operations")
-
-                logger.info("GOT-OCR model loaded successfully")
-                return True
-            except Exception as e:
-                cls._model = None
-                cls._tokenizer = None
-                logger.error(f"Failed to load GOT-OCR model: {str(e)}")
-                return False
-        return True
-
-    @classmethod
-    def release_model(cls):
-        """Release the model from memory."""
-        try:
-            import torch
-
-            if cls._model is not None:
-                del cls._model
-                cls._model = None
-
-            if cls._tokenizer is not None:
-                del cls._tokenizer
-                cls._tokenizer = None
-
-            # Clear CUDA cache if available
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            logger.info("GOT-OCR model released from memory")
-        except Exception as e:
-            logger.error(f"Error releasing model: {str(e)}")
+    def _setup_repository(cls) -> bool:
+        """Set up the GOT-OCR2.0 repository if it's not already set up."""
+        if cls._repo_path is not None and os.path.exists(cls._repo_path):
+            return True
+
+        try:
+            # Create a temporary directory for the repository
+            repo_dir = os.path.join(tempfile.gettempdir(), "GOT-OCR2.0")
+
+            # Check if the repository already exists
+            if not os.path.exists(repo_dir):
+                logger.info("Cloning GOT-OCR2.0 repository...")
+                subprocess.run(
+                    ["git", "clone", "https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git", repo_dir],
+                    check=True
+                )
+            else:
+                logger.info("GOT-OCR2.0 repository already exists, skipping clone")
+
+            cls._repo_path = repo_dir
+
+            # Set up the weights directory
+            weights_dir = os.path.join(repo_dir, "GOT_weights")
+            if not os.path.exists(weights_dir):
+                os.makedirs(weights_dir, exist_ok=True)
+
+            cls._weights_path = weights_dir
+
+            # Check if weights exist, if not download them
+            weight_files = [f for f in os.listdir(weights_dir) if f.endswith(".bin") or f.endswith(".safetensors")]
+            if not weight_files:
+                logger.info("Downloading GOT-OCR2.0 weights...")
+                logger.info("This may take some time depending on your internet connection.")
+                logger.info("Downloading from Hugging Face repository...")
+
+                # Use Hugging Face CLI to download the model
+                subprocess.run(
+                    ["huggingface-cli", "download", "stepfun-ai/GOT-OCR2_0", "--local-dir", weights_dir],
+                    check=True
+                )
+
+                # Additional check to verify downloads
+                weight_files = [f for f in os.listdir(weights_dir) if f.endswith(".bin") or f.endswith(".safetensors")]
+                if not weight_files:
+                    logger.error("Failed to download weights. Please download them manually and place in GOT_weights directory.")
+                    return False
+
+            logger.info("GOT-OCR2.0 repository and weights set up successfully")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to set up GOT-OCR2.0 repository: {str(e)}")
+            return False
 
     def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
         """Parse a document using GOT-OCR 2.0.
@@ -170,12 +152,9 @@
                 "tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0"
             )
 
-        # Load model if not already loaded
-        if not self._load_model():
-            raise RuntimeError("Failed to load GOT-OCR model")
-
-        # Import torch here to ensure it's available
-        import torch
+        # Set up the repository
+        if not self._setup_repository():
+            raise RuntimeError("Failed to set up GOT-OCR2.0 repository")
 
         # Validate file path and extension
         file_path = Path(file_path)
@@ -192,87 +171,76 @@
         ocr_type = "format" if ocr_method == "format" else "ocr"
         logger.info(f"Using OCR method: {ocr_type}")
 
-        # Process the image
+        # Check if render is specified in kwargs
+        render = kwargs.get('render', False)
+
+        # Process the image using the GOT-OCR repository
         try:
             logger.info(f"Processing image with GOT-OCR: {file_path}")
 
-            # First attempt: Normal processing with autocast
-            try:
-                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                    # Use format=True parameter when ocr_type is "format"
-                    if ocr_type == "format":
-                        result = self._model.chat(
-                            self._tokenizer,
-                            str(file_path),
-                            ocr_type='format'
-                        )
-                    else:
-                        result = self._model.chat(
-                            self._tokenizer,
-                            str(file_path),
-                            ocr_type='ocr'
-                        )
-
-                # Convert LaTeX to Markdown for better display in UI
-                if ocr_type == "format":
-                    logger.info("Converting formatted LaTeX output to Markdown")
-                    result = latex_to_markdown(result)
-
-                return result
-            except RuntimeError as e:
-                # Check if it's a bfloat16 error
-                if "bfloat16" in str(e) or "BFloat16" in str(e):
-                    logger.warning("Encountered bfloat16 error, trying float16 fallback")
-
-                    # Second attempt: More aggressive float16 forcing
-                    try:
-                        # Ensure model is float16
-                        self._model = self._model.half()
-
-                        # Set default dtype temporarily
-                        original_dtype = torch.get_default_dtype()
-                        torch.set_default_dtype(torch.float16)
-
-                        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                            # Use format=True parameter when ocr_type is "format"
-                            if ocr_type == "format":
-                                result = self._model.chat(
-                                    self._tokenizer,
-                                    str(file_path),
-                                    ocr_type='format'
-                                )
-                            else:
-                                result = self._model.chat(
-                                    self._tokenizer,
-                                    str(file_path),
-                                    ocr_type='ocr'
-                                )
-
-                        # Restore default dtype
-                        torch.set_default_dtype(original_dtype)
-
-                        # Convert LaTeX to Markdown for better display in UI
-                        if ocr_type == "format":
-                            logger.info("Converting formatted LaTeX output to Markdown")
-                            result = latex_to_markdown(result)
-
-                        return result
-                    except Exception as inner_e:
-                        logger.error(f"Float16 fallback failed: {str(inner_e)}")
-                        raise RuntimeError(
-                            f"Failed to process image with GOT-OCR: {str(inner_e)}"
-                        )
-                else:
-                    # Not a bfloat16 error, re-raise
-                    raise
-
+            # Create the command for running the GOT-OCR script
+            cmd = [
+                sys.executable,
+                os.path.join(self._repo_path, "GOT", "demo", "run_ocr_2.0.py"),
+                "--model-name", self._weights_path,
+                "--image-file", str(file_path),
+                "--type", ocr_type
+            ]
+
+            # Add render flag if required
+            if render:
+                cmd.append("--render")
+
+            # Check if box or color is specified in kwargs
+            if 'box' in kwargs and kwargs['box']:
+                cmd.extend(["--box", str(kwargs['box'])])
+
+            if 'color' in kwargs and kwargs['color']:
+                cmd.extend(["--color", kwargs['color']])
+
+            # Run the command and capture output
+            logger.info(f"Running command: {' '.join(cmd)}")
+            process = subprocess.run(
+                cmd,
+                check=True,
+                capture_output=True,
+                text=True
+            )
+
+            # Process the output
+            result = process.stdout.strip()
+
+            # If render was requested, find and return the path to the HTML file
+            if render:
+                # The rendered results are in /results/demo.html according to the README
+                html_result_path = os.path.join(self._repo_path, "results", "demo.html")
+                if os.path.exists(html_result_path):
+                    with open(html_result_path, 'r') as f:
+                        html_content = f.read()
+                    return html_content
+                else:
+                    logger.warning(f"Rendered HTML file not found at {html_result_path}")
+
+            # Check if we need to convert from LaTeX to Markdown
+            if ocr_type == "format":
+                logger.info("Converting formatted LaTeX output to Markdown using latex2markdown")
+                # Use the latex2markdown package instead of custom converter
+                l2m = latex2markdown.LaTeX2Markdown(result)
+                result = l2m.to_markdown()
+
+            return result
+
+        except subprocess.CalledProcessError as e:
+            logger.error(f"Error running GOT-OCR command: {str(e)}")
+            logger.error(f"Stderr: {e.stderr}")
+            raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
+
         except Exception as e:
             logger.error(f"Error processing image with GOT-OCR: {str(e)}")
 
             # Handle specific errors with helpful messages
             error_type = type(e).__name__
             if error_type == 'OutOfMemoryError':
-                self.release_model()
                 raise RuntimeError(
                     "GPU out of memory while processing with GOT-OCR. "
                     "Try using a smaller image or a different parser."
@@ -280,11 +248,17 @@
 
             # Generic error
             raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
+
+    @classmethod
+    def release_model(cls):
+        """Release the model resources."""
+        # No need to do anything here since we're not loading the model directly
+        # We're using subprocess calls instead
+        pass
 
 # Try to register the parser
 try:
     # Only check basic imports, detailed dependency check happens in parse method
-    import numpy
     import torch
     ParserRegistry.register(GotOcrParser)
     logger.info("GOT-OCR parser registered successfully")