Spaces:
Running
on
Zero
Running
on
Zero
Please work
Browse files- app.py +28 -19
- requirements.txt +13 -14
- setup.sh +9 -50
- src/parsers/got_ocr_parser.py +291 -298
- src/utils/__init__.py +0 -5
- src/utils/latex_converter.py +0 -186
app.py
CHANGED
@@ -3,7 +3,6 @@ import os
|
|
3 |
import subprocess
|
4 |
import shutil
|
5 |
from pathlib import Path
|
6 |
-
import urllib.request
|
7 |
import logging
|
8 |
|
9 |
# Configure logging - Add this section to suppress httpx logs
|
@@ -24,28 +23,38 @@ try:
|
|
24 |
except Exception as e:
|
25 |
print(f"Error running setup.sh: {e}")
|
26 |
|
27 |
-
# Check
|
28 |
try:
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
32 |
else:
|
33 |
-
print("WARNING:
|
34 |
-
except
|
35 |
-
print("WARNING:
|
|
|
36 |
|
37 |
-
# Check if
|
38 |
try:
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# Check if spaces module is installed (needed for ZeroGPU)
|
51 |
try:
|
|
|
3 |
import subprocess
|
4 |
import shutil
|
5 |
from pathlib import Path
|
|
|
6 |
import logging
|
7 |
|
8 |
# Configure logging - Add this section to suppress httpx logs
|
|
|
23 |
except Exception as e:
|
24 |
print(f"Error running setup.sh: {e}")
|
25 |
|
26 |
+
# Check for PyTorch and CUDA availability (needed for GOT-OCR)
|
27 |
try:
|
28 |
+
import torch
|
29 |
+
print(f"PyTorch version: {torch.__version__}")
|
30 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
31 |
+
if torch.cuda.is_available():
|
32 |
+
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
33 |
+
print(f"CUDA version: {torch.version.cuda}")
|
34 |
else:
|
35 |
+
print("WARNING: CUDA not available. GOT-OCR performs best with GPU acceleration.")
|
36 |
+
except ImportError:
|
37 |
+
print("WARNING: PyTorch not installed. Installing PyTorch...")
|
38 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "torch", "torchvision"], check=False)
|
39 |
|
40 |
+
# Check if transformers is installed (needed for GOT-OCR)
|
41 |
try:
|
42 |
+
import transformers
|
43 |
+
print(f"Transformers version: {transformers.__version__}")
|
44 |
+
except ImportError:
|
45 |
+
print("WARNING: Transformers not installed. Installing transformers from GitHub...")
|
46 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "git+https://github.com/huggingface/transformers.git@main", "accelerate", "verovio"], check=False)
|
47 |
+
|
48 |
+
# Check if numpy is installed with the correct version
|
49 |
+
try:
|
50 |
+
import numpy as np
|
51 |
+
print(f"NumPy version: {np.__version__}")
|
52 |
+
if np.__version__ != "1.26.3":
|
53 |
+
print("WARNING: NumPy version mismatch. Installing exact version 1.26.3...")
|
54 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "numpy==1.26.3"], check=False)
|
55 |
+
except ImportError:
|
56 |
+
print("WARNING: NumPy not installed. Installing NumPy 1.26.3...")
|
57 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "numpy==1.26.3"], check=False)
|
58 |
|
59 |
# Check if spaces module is installed (needed for ZeroGPU)
|
60 |
try:
|
requirements.txt
CHANGED
@@ -1,28 +1,27 @@
|
|
1 |
# Core dependencies
|
2 |
gradio==5.14.0
|
3 |
markdown==3.7
|
4 |
-
|
5 |
-
numpy
|
6 |
-
|
|
|
|
|
7 |
|
8 |
# Image processing
|
9 |
-
opencv-python
|
10 |
|
11 |
# Utility dependencies
|
12 |
python-dotenv>=1.0.0
|
13 |
pydantic==2.7.1
|
14 |
-
gitpython>=3.1.0 # For cloning repositories
|
15 |
latex2markdown>=0.1.0 # For LaTeX to Markdown conversion
|
16 |
|
17 |
# Gemini API client
|
18 |
google-genai>=0.1.0
|
19 |
|
20 |
-
# GOT-OCR dependencies
|
21 |
-
torch
|
22 |
-
torchvision
|
23 |
-
transformers
|
24 |
-
|
25 |
-
verovio
|
26 |
-
|
27 |
-
safetensors==0.4.3 # Updated to meet minimum version required by accelerate
|
28 |
-
huggingface_hub[cli]>=0.19.0 # For downloading models from Hugging Face
|
|
|
1 |
# Core dependencies
|
2 |
gradio==5.14.0
|
3 |
markdown==3.7
|
4 |
+
pillow # Match exact dependency from GOT-OCR
|
5 |
+
numpy==1.26.3 # Match exact dependency from GOT-OCR
|
6 |
+
|
7 |
+
# For ZeroGPU support
|
8 |
+
spaces
|
9 |
|
10 |
# Image processing
|
11 |
+
opencv-python # Match exact dependency from GOT-OCR
|
12 |
|
13 |
# Utility dependencies
|
14 |
python-dotenv>=1.0.0
|
15 |
pydantic==2.7.1
|
|
|
16 |
latex2markdown>=0.1.0 # For LaTeX to Markdown conversion
|
17 |
|
18 |
# Gemini API client
|
19 |
google-genai>=0.1.0
|
20 |
|
21 |
+
# GOT-OCR dependencies - exactly as in original
|
22 |
+
torch
|
23 |
+
torchvision
|
24 |
+
git+https://github.com/huggingface/transformers.git@main
|
25 |
+
accelerate
|
26 |
+
verovio # Added missing dependency
|
27 |
+
huggingface_hub[cli]>=0.19.0
|
|
|
|
setup.sh
CHANGED
@@ -14,13 +14,10 @@ if [ "$EUID" -eq 0 ]; then
|
|
14 |
echo "Installing system dependencies..."
|
15 |
apt-get update && apt-get install -y \
|
16 |
wget \
|
17 |
-
pkg-config
|
18 |
-
git \
|
19 |
-
tree # Add tree for directory structure visualization
|
20 |
echo "System dependencies installed successfully"
|
21 |
else
|
22 |
echo "Not running as root. Skipping system dependencies installation."
|
23 |
-
echo "Make sure git is installed on your system for GOT-OCR to work properly."
|
24 |
fi
|
25 |
|
26 |
# Install NumPy first as it's required by many other packages
|
@@ -30,62 +27,24 @@ echo "NumPy installed successfully"
|
|
30 |
|
31 |
# Install Python dependencies
|
32 |
echo "Installing Python dependencies..."
|
33 |
-
pip install -q -U pillow opencv-python
|
34 |
pip install -q -U google-genai
|
35 |
pip install -q -U latex2markdown
|
36 |
echo "Python dependencies installed successfully"
|
37 |
|
38 |
-
# Install GOT-OCR dependencies
|
39 |
-
echo "Installing GOT-OCR dependencies..."
|
40 |
-
pip install -q -U torch
|
41 |
-
|
42 |
-
|
43 |
-
#
|
44 |
-
echo "
|
45 |
-
pip install -q -U "huggingface_hub[cli]"
|
46 |
-
echo "Hugging Face CLI installed successfully"
|
47 |
|
48 |
# Install spaces module for ZeroGPU support
|
49 |
echo "Installing spaces module for ZeroGPU support..."
|
50 |
pip install -q -U spaces
|
51 |
echo "Spaces module installed successfully"
|
52 |
|
53 |
-
# Add debug section for GOT-OCR repo
|
54 |
-
echo "===== GOT-OCR Repository Debugging ====="
|
55 |
-
|
56 |
-
# Clone the repository for inspection (if it doesn't exist)
|
57 |
-
TEMP_DIR="/tmp"
|
58 |
-
REPO_DIR="${TEMP_DIR}/GOT-OCR2.0"
|
59 |
-
|
60 |
-
if [ ! -d "$REPO_DIR" ]; then
|
61 |
-
echo "Cloning GOT-OCR2.0 repository for debugging..."
|
62 |
-
git clone https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git "$REPO_DIR"
|
63 |
-
else
|
64 |
-
echo "GOT-OCR2.0 repository already exists at $REPO_DIR"
|
65 |
-
fi
|
66 |
-
|
67 |
-
# Check the repository structure
|
68 |
-
echo "GOT-OCR2.0 repository structure:"
|
69 |
-
if command -v tree &> /dev/null; then
|
70 |
-
tree -L 3 "$REPO_DIR"
|
71 |
-
else
|
72 |
-
find "$REPO_DIR" -type d -maxdepth 3 | sort
|
73 |
-
fi
|
74 |
-
|
75 |
-
# Check if the demo script exists
|
76 |
-
DEMO_SCRIPT="${REPO_DIR}/GOT/demo/run_ocr_2.0.py"
|
77 |
-
if [ -f "$DEMO_SCRIPT" ]; then
|
78 |
-
echo "Demo script found at: $DEMO_SCRIPT"
|
79 |
-
else
|
80 |
-
echo "ERROR: Demo script not found at: $DEMO_SCRIPT"
|
81 |
-
|
82 |
-
# Search for the script in the repository
|
83 |
-
echo "Searching for run_ocr_2.0.py in the repository..."
|
84 |
-
find "$REPO_DIR" -name "run_ocr_2.0.py" -type f
|
85 |
-
fi
|
86 |
-
|
87 |
-
echo "===== End of GOT-OCR Debugging ====="
|
88 |
-
|
89 |
# Install the project in development mode only if setup.py or pyproject.toml exists
|
90 |
if [ -f "setup.py" ] || [ -f "pyproject.toml" ]; then
|
91 |
echo "Installing project in development mode..."
|
|
|
14 |
echo "Installing system dependencies..."
|
15 |
apt-get update && apt-get install -y \
|
16 |
wget \
|
17 |
+
pkg-config
|
|
|
|
|
18 |
echo "System dependencies installed successfully"
|
19 |
else
|
20 |
echo "Not running as root. Skipping system dependencies installation."
|
|
|
21 |
fi
|
22 |
|
23 |
# Install NumPy first as it's required by many other packages
|
|
|
27 |
|
28 |
# Install Python dependencies
|
29 |
echo "Installing Python dependencies..."
|
30 |
+
pip install -q -U pillow opencv-python
|
31 |
pip install -q -U google-genai
|
32 |
pip install -q -U latex2markdown
|
33 |
echo "Python dependencies installed successfully"
|
34 |
|
35 |
+
# Install GOT-OCR transformers dependencies
|
36 |
+
echo "Installing GOT-OCR transformers dependencies..."
|
37 |
+
pip install -q -U torch torchvision
|
38 |
+
pip install -q -U "git+https://github.com/huggingface/transformers.git@main" accelerate verovio
|
39 |
+
pip install -q -U "huggingface_hub[cli]>=0.19.0"
|
40 |
+
pip install -q -U "numpy==1.26.3" # Exact version as in original
|
41 |
+
echo "GOT-OCR transformers dependencies installed successfully"
|
|
|
|
|
42 |
|
43 |
# Install spaces module for ZeroGPU support
|
44 |
echo "Installing spaces module for ZeroGPU support..."
|
45 |
pip install -q -U spaces
|
46 |
echo "Spaces module installed successfully"
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
# Install the project in development mode only if setup.py or pyproject.toml exists
|
49 |
if [ -f "setup.py" ] || [ -f "pyproject.toml" ]; then
|
50 |
echo "Installing project in development mode..."
|
src/parsers/got_ocr_parser.py
CHANGED
@@ -2,7 +2,6 @@ from pathlib import Path
|
|
2 |
import os
|
3 |
import logging
|
4 |
import sys
|
5 |
-
import subprocess
|
6 |
import tempfile
|
7 |
import shutil
|
8 |
from typing import Dict, List, Optional, Any, Union
|
@@ -17,35 +16,23 @@ except ImportError:
|
|
17 |
from src.parsers.parser_interface import DocumentParser
|
18 |
from src.parsers.parser_registry import ParserRegistry
|
19 |
|
20 |
-
# Import latex2markdown
|
21 |
import latex2markdown
|
22 |
|
23 |
# Configure logging
|
24 |
logger = logging.getLogger(__name__)
|
25 |
-
# Set logger level to DEBUG for more verbose output
|
26 |
logger.setLevel(logging.DEBUG)
|
27 |
|
28 |
-
# Add patch for bfloat16 at the module level
|
29 |
-
if 'torch' in sys.modules:
|
30 |
-
torch_module = sys.modules['torch']
|
31 |
-
if hasattr(torch_module, 'bfloat16'):
|
32 |
-
# Create a reference to the original bfloat16 function
|
33 |
-
original_bfloat16 = torch_module.bfloat16
|
34 |
-
# Replace it with float16
|
35 |
-
torch_module.bfloat16 = torch_module.float16
|
36 |
-
logger.info("Patched torch.bfloat16 to use torch.float16 instead")
|
37 |
-
|
38 |
class GotOcrParser(DocumentParser):
|
39 |
-
"""Parser implementation using GOT-OCR 2.0 for document text extraction using
|
40 |
|
41 |
-
This implementation uses the
|
42 |
-
|
43 |
"""
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
-
|
48 |
-
_got_parent_dir = None # New variable to store the parent directory of the GOT module
|
49 |
|
50 |
@classmethod
|
51 |
def get_name(cls) -> str:
|
@@ -63,6 +50,21 @@ class GotOcrParser(DocumentParser):
|
|
63 |
"id": "format",
|
64 |
"name": "Formatted Text",
|
65 |
"default_params": {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
}
|
67 |
]
|
68 |
|
@@ -76,130 +78,52 @@ class GotOcrParser(DocumentParser):
|
|
76 |
try:
|
77 |
import torch
|
78 |
import transformers
|
79 |
-
import tiktoken
|
80 |
|
81 |
# Check CUDA availability if using torch
|
82 |
if hasattr(torch, 'cuda') and not torch.cuda.is_available():
|
83 |
logger.warning("CUDA is not available. GOT-OCR performs best with GPU acceleration.")
|
84 |
|
85 |
-
# Check for latex2markdown
|
86 |
-
try:
|
87 |
-
import latex2markdown
|
88 |
-
logger.info("latex2markdown package found")
|
89 |
-
except ImportError:
|
90 |
-
logger.warning("latex2markdown package not found. Installing...")
|
91 |
-
subprocess.run(
|
92 |
-
[sys.executable, "-m", "pip", "install", "latex2markdown"],
|
93 |
-
check=True
|
94 |
-
)
|
95 |
-
|
96 |
return True
|
97 |
except ImportError as e:
|
98 |
logger.error(f"Missing dependency: {e}")
|
99 |
return False
|
100 |
|
101 |
@classmethod
|
102 |
-
def
|
103 |
-
"""
|
104 |
-
if cls.
|
105 |
-
logger.debug(f"Repository already set up at: {cls._repo_path}")
|
106 |
return True
|
107 |
|
108 |
try:
|
109 |
-
|
110 |
-
|
111 |
-
logger.debug(f"Repository directory: {repo_dir}")
|
112 |
-
|
113 |
-
# Check if the repository already exists
|
114 |
-
if not os.path.exists(repo_dir):
|
115 |
-
logger.info(f"Cloning GOT-OCR2.0 repository to {repo_dir}...")
|
116 |
-
subprocess.run(
|
117 |
-
["git", "clone", "https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git", repo_dir],
|
118 |
-
check=True
|
119 |
-
)
|
120 |
-
else:
|
121 |
-
logger.info(f"GOT-OCR2.0 repository already exists at {repo_dir}, skipping clone")
|
122 |
|
123 |
-
|
|
|
124 |
|
125 |
-
#
|
126 |
-
|
127 |
-
try:
|
128 |
-
result = subprocess.run(
|
129 |
-
["find", repo_dir, "-type", "d", "-maxdepth", "3"],
|
130 |
-
check=True,
|
131 |
-
capture_output=True,
|
132 |
-
text=True
|
133 |
-
)
|
134 |
-
for line in result.stdout.splitlines():
|
135 |
-
logger.debug(f" {line}")
|
136 |
-
except Exception as e:
|
137 |
-
logger.warning(f"Could not list repository contents: {e}")
|
138 |
|
139 |
-
|
140 |
-
demo_script = os.path.join(repo_dir, "GOT", "demo", "run_ocr_2.0.py")
|
141 |
-
if os.path.exists(demo_script):
|
142 |
-
logger.info(f"Found demo script at: {demo_script}")
|
143 |
-
cls._got_parent_dir = repo_dir # Parent dir is the repo dir
|
144 |
-
else:
|
145 |
-
logger.warning(f"Demo script not found at expected path: {demo_script}")
|
146 |
-
# Try to find it
|
147 |
-
logger.info("Searching for run_ocr_2.0.py in the repository...")
|
148 |
-
try:
|
149 |
-
find_result = subprocess.run(
|
150 |
-
["find", repo_dir, "-name", "run_ocr_2.0.py", "-type", "f"],
|
151 |
-
check=True,
|
152 |
-
capture_output=True,
|
153 |
-
text=True
|
154 |
-
)
|
155 |
-
if find_result.stdout.strip():
|
156 |
-
found_paths = find_result.stdout.strip().splitlines()
|
157 |
-
logger.info(f"Found script at alternative locations: {found_paths}")
|
158 |
-
# Use the first found path as fallback
|
159 |
-
if found_paths:
|
160 |
-
alternative_path = found_paths[0]
|
161 |
-
logger.info(f"Using alternative path: {alternative_path}")
|
162 |
-
|
163 |
-
# Set the parent directory for the GOT module
|
164 |
-
# We need to set it to the directory that contains the GOT-OCR-2.0-master directory
|
165 |
-
if "GOT-OCR-2.0-master" in alternative_path:
|
166 |
-
cls._got_parent_dir = os.path.join(repo_dir, "GOT-OCR-2.0-master")
|
167 |
-
logger.info(f"Parent directory for GOT module: {cls._got_parent_dir}")
|
168 |
-
except Exception as e:
|
169 |
-
logger.warning(f"Could not search for script: {e}")
|
170 |
|
171 |
-
#
|
172 |
-
|
173 |
-
if not os.path.exists(weights_dir):
|
174 |
-
os.makedirs(weights_dir, exist_ok=True)
|
175 |
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
178 |
|
179 |
-
#
|
180 |
-
|
181 |
-
if not weight_files:
|
182 |
-
logger.info("Downloading GOT-OCR2.0 weights...")
|
183 |
-
logger.info("This may take some time depending on your internet connection.")
|
184 |
-
logger.info("Downloading from Hugging Face repository...")
|
185 |
-
|
186 |
-
# Use Hugging Face CLI to download the model
|
187 |
-
subprocess.run(
|
188 |
-
["huggingface-cli", "download", "stepfun-ai/GOT-OCR2_0", "--local-dir", weights_dir],
|
189 |
-
check=True
|
190 |
-
)
|
191 |
-
|
192 |
-
# Additional check to verify downloads
|
193 |
-
weight_files = [f for f in os.listdir(weights_dir) if f.endswith(".bin") or f.endswith(".safetensors")]
|
194 |
-
if not weight_files:
|
195 |
-
logger.error("Failed to download weights. Please download them manually and place in GOT_weights directory.")
|
196 |
-
return False
|
197 |
|
198 |
-
logger.info("GOT-
|
199 |
return True
|
200 |
|
201 |
except Exception as e:
|
202 |
-
logger.error(f"Failed to
|
203 |
return False
|
204 |
|
205 |
def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
|
@@ -207,8 +131,10 @@ class GotOcrParser(DocumentParser):
|
|
207 |
|
208 |
Args:
|
209 |
file_path: Path to the image file
|
210 |
-
ocr_method: OCR method to use ('plain'
|
211 |
**kwargs: Additional arguments to pass to the model
|
|
|
|
|
212 |
|
213 |
Returns:
|
214 |
Extracted text from the image, converted to Markdown if formatted
|
@@ -217,14 +143,9 @@ class GotOcrParser(DocumentParser):
|
|
217 |
if not self._check_dependencies():
|
218 |
raise ImportError(
|
219 |
"Required dependencies are missing. Please install: "
|
220 |
-
"torch
|
221 |
-
"tiktoken==0.6.0 verovio==4.3.1 accelerate==0.28.0"
|
222 |
)
|
223 |
|
224 |
-
# Set up the repository
|
225 |
-
if not self._setup_repository():
|
226 |
-
raise RuntimeError("Failed to set up GOT-OCR2.0 repository")
|
227 |
-
|
228 |
# Validate file path and extension
|
229 |
file_path = Path(file_path)
|
230 |
if not file_path.exists():
|
@@ -236,149 +157,41 @@ class GotOcrParser(DocumentParser):
|
|
236 |
f"Received file with extension: {file_path.suffix}"
|
237 |
)
|
238 |
|
239 |
-
# Determine OCR
|
240 |
-
|
241 |
-
|
|
|
|
|
242 |
|
243 |
-
#
|
244 |
-
|
245 |
|
246 |
-
#
|
|
|
|
|
|
|
|
|
247 |
try:
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
if found_paths:
|
268 |
-
script_path = found_paths[0]
|
269 |
-
logger.info(f"Found script at alternative location: {script_path}")
|
270 |
-
else:
|
271 |
-
raise FileNotFoundError(f"Could not find run_ocr_2.0.py in repository: {self._repo_path}")
|
272 |
-
except Exception as search_e:
|
273 |
-
logger.error(f"Error searching for script: {str(search_e)}")
|
274 |
-
raise FileNotFoundError(f"Script not found and search failed: {str(search_e)}")
|
275 |
-
|
276 |
-
# Create a batch/shell script to run the Python script with the correct PYTHONPATH
|
277 |
-
# This ensures the GOT module can be imported and patches bfloat16
|
278 |
-
temp_script = None
|
279 |
-
try:
|
280 |
-
# Create a temporary script
|
281 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=False) as f:
|
282 |
-
temp_script = f.name
|
283 |
-
parent_dir = self._got_parent_dir or os.path.dirname(os.path.dirname(script_path))
|
284 |
-
|
285 |
-
# Add commands to the script
|
286 |
-
f.write("#!/bin/bash\n")
|
287 |
-
f.write(f"cd {parent_dir}\n") # Change to parent directory
|
288 |
-
f.write("export PYTHONPATH=$PYTHONPATH:$(pwd)\n") # Add current directory to PYTHONPATH
|
289 |
-
|
290 |
-
# Add a Python script to patch torch.bfloat16
|
291 |
-
patch_script = os.path.join(tempfile.gettempdir(), "patch_torch.py")
|
292 |
-
with open(patch_script, 'w') as patch_f:
|
293 |
-
patch_f.write("""
|
294 |
-
import sys
|
295 |
-
import torch
|
296 |
-
|
297 |
-
# Patch torch.bfloat16 to use torch.float16 instead
|
298 |
-
if hasattr(torch, 'bfloat16'):
|
299 |
-
# Save reference to original bfloat16
|
300 |
-
original_bfloat16 = torch.bfloat16
|
301 |
-
# Replace with float16
|
302 |
-
torch.bfloat16 = torch.float16
|
303 |
-
print("Successfully patched torch.bfloat16 to use torch.float16")
|
304 |
-
|
305 |
-
# Also patch torch.autocast context manager for CUDA
|
306 |
-
original_autocast = torch.autocast
|
307 |
-
def patched_autocast(*args, **kwargs):
|
308 |
-
# Force dtype to float16 when CUDA is involved
|
309 |
-
if args and args[0] == "cuda" and kwargs.get("dtype") == torch.bfloat16:
|
310 |
-
kwargs["dtype"] = torch.float16
|
311 |
-
print(f"Autocast: Changed bfloat16 to float16 for {args}")
|
312 |
-
return original_autocast(*args, **kwargs)
|
313 |
-
|
314 |
-
torch.autocast = patched_autocast
|
315 |
-
print("Successfully patched torch.autocast to ensure float16 is used instead of bfloat16")
|
316 |
-
""")
|
317 |
-
|
318 |
-
# Build the command with the patch included
|
319 |
-
py_cmd = [
|
320 |
-
sys.executable,
|
321 |
-
"-c",
|
322 |
-
f"import sys; sys.path.insert(0, '{parent_dir}'); "
|
323 |
-
f"exec(open('{patch_script}').read()); "
|
324 |
-
f"import runpy; runpy.run_path('{script_path}', run_name='__main__')"
|
325 |
-
]
|
326 |
-
|
327 |
-
# Add the arguments
|
328 |
-
py_cmd.extend(["--model-name", self._weights_path])
|
329 |
-
py_cmd.extend(["--image-file", str(file_path)])
|
330 |
-
py_cmd.extend(["--type", ocr_type])
|
331 |
-
|
332 |
-
# Add render flag if required
|
333 |
-
if render:
|
334 |
-
py_cmd.append("--render")
|
335 |
-
|
336 |
-
# Check if box or color is specified in kwargs
|
337 |
-
if 'box' in kwargs and kwargs['box']:
|
338 |
-
py_cmd.extend(["--box", str(kwargs['box'])])
|
339 |
-
|
340 |
-
if 'color' in kwargs and kwargs['color']:
|
341 |
-
py_cmd.extend(["--color", kwargs['color']])
|
342 |
-
|
343 |
-
# Add the command to the script
|
344 |
-
f.write(" ".join(py_cmd) + "\n")
|
345 |
-
|
346 |
-
# Make the script executable
|
347 |
-
os.chmod(temp_script, 0o755)
|
348 |
-
|
349 |
-
# Run the script with GPU access if available
|
350 |
-
result = self._run_with_gpu(temp_script)
|
351 |
-
|
352 |
-
# If render was requested, find and return the path to the HTML file
|
353 |
-
if render:
|
354 |
-
# The rendered results are in /results/demo.html according to the README
|
355 |
-
html_result_path = os.path.join(self._repo_path, "results", "demo.html")
|
356 |
-
if os.path.exists(html_result_path):
|
357 |
-
with open(html_result_path, 'r') as f:
|
358 |
-
html_content = f.read()
|
359 |
-
return html_content
|
360 |
-
else:
|
361 |
-
logger.warning(f"Rendered HTML file not found at {html_result_path}")
|
362 |
-
|
363 |
-
# Check if we need to convert from LaTeX to Markdown
|
364 |
-
if ocr_type == "format":
|
365 |
-
logger.info("Converting formatted LaTeX output to Markdown using latex2markdown")
|
366 |
-
# Use the latex2markdown package instead of custom converter
|
367 |
-
l2m = latex2markdown.LaTeX2Markdown(result)
|
368 |
-
result = l2m.to_markdown()
|
369 |
-
|
370 |
-
return result
|
371 |
-
|
372 |
-
finally:
|
373 |
-
# Clean up the temporary script
|
374 |
-
if temp_script and os.path.exists(temp_script):
|
375 |
-
os.unlink(temp_script)
|
376 |
|
377 |
-
except subprocess.CalledProcessError as e:
|
378 |
-
logger.error(f"Error running GOT-OCR command: {str(e)}")
|
379 |
-
logger.error(f"Stderr: {e.stderr}")
|
380 |
-
raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
|
381 |
-
|
382 |
except Exception as e:
|
383 |
logger.error(f"Error processing image with GOT-OCR: {str(e)}")
|
384 |
|
@@ -389,57 +202,237 @@ print("Successfully patched torch.autocast to ensure float16 is used instead of
|
|
389 |
"GPU out of memory while processing with GOT-OCR. "
|
390 |
"Try using a smaller image or a different parser."
|
391 |
)
|
|
|
|
|
|
|
|
|
|
|
392 |
|
393 |
# Generic error
|
394 |
raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
|
395 |
-
|
396 |
-
# Define a method that will be decorated with spaces.GPU to ensure GPU access
|
397 |
-
def _run_with_gpu(self, script_path):
|
398 |
-
"""Run a script with GPU access using the spaces.GPU decorator if available."""
|
399 |
-
if HAS_SPACES:
|
400 |
-
# Use the spaces.GPU decorator to ensure GPU access
|
401 |
-
return self._run_script_with_gpu_allocation(script_path)
|
402 |
-
else:
|
403 |
-
# Fall back to regular execution if spaces module is not available
|
404 |
-
logger.info(f"Running command through wrapper script without ZeroGPU: {script_path}")
|
405 |
-
process = subprocess.run(
|
406 |
-
[script_path],
|
407 |
-
check=True,
|
408 |
-
capture_output=True,
|
409 |
-
text=True
|
410 |
-
)
|
411 |
-
return process.stdout.strip()
|
412 |
|
413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
if HAS_SPACES:
|
415 |
@spaces.GPU(duration=180) # Allocate up to 3 minutes for OCR processing
|
416 |
-
def
|
417 |
-
"""
|
418 |
-
logger.info(
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
424 |
)
|
425 |
-
return process.stdout.strip()
|
426 |
else:
|
427 |
# Define a dummy method if spaces is not available
|
428 |
-
def
|
429 |
# This should never be called if HAS_SPACES is False
|
430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
|
432 |
@classmethod
|
433 |
def release_model(cls):
|
434 |
"""Release the model resources."""
|
435 |
-
|
436 |
-
|
437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
|
439 |
# Try to register the parser
|
440 |
try:
|
441 |
# Only check basic imports, detailed dependency check happens in parse method
|
442 |
import torch
|
|
|
443 |
ParserRegistry.register(GotOcrParser)
|
444 |
logger.info("GOT-OCR parser registered successfully")
|
445 |
except ImportError as e:
|
|
|
2 |
import os
|
3 |
import logging
|
4 |
import sys
|
|
|
5 |
import tempfile
|
6 |
import shutil
|
7 |
from typing import Dict, List, Optional, Any, Union
|
|
|
16 |
from src.parsers.parser_interface import DocumentParser
|
17 |
from src.parsers.parser_registry import ParserRegistry
|
18 |
|
19 |
+
# Import latex2markdown for conversion
|
20 |
import latex2markdown
|
21 |
|
22 |
# Configure logging
|
23 |
logger = logging.getLogger(__name__)
|
|
|
24 |
logger.setLevel(logging.DEBUG)
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
class GotOcrParser(DocumentParser):
|
27 |
+
"""Parser implementation using GOT-OCR 2.0 for document text extraction using transformers.
|
28 |
|
29 |
+
This implementation uses the transformers model directly for better integration with
|
30 |
+
ZeroGPU and avoids subprocess complexity.
|
31 |
"""
|
32 |
|
33 |
+
# Class variables to hold model and processor
|
34 |
+
_model = None
|
35 |
+
_processor = None
|
|
|
36 |
|
37 |
@classmethod
|
38 |
def get_name(cls) -> str:
|
|
|
50 |
"id": "format",
|
51 |
"name": "Formatted Text",
|
52 |
"default_params": {}
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"id": "box",
|
56 |
+
"name": "OCR with Bounding Box",
|
57 |
+
"default_params": {"box": "[100,100,200,200]"} # Default box coordinates
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"id": "color",
|
61 |
+
"name": "OCR with Color Filter",
|
62 |
+
"default_params": {"color": "red"} # Default color filter
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"id": "multi_crop",
|
66 |
+
"name": "Multi-crop OCR",
|
67 |
+
"default_params": {}
|
68 |
}
|
69 |
]
|
70 |
|
|
|
78 |
try:
|
79 |
import torch
|
80 |
import transformers
|
|
|
81 |
|
82 |
# Check CUDA availability if using torch
|
83 |
if hasattr(torch, 'cuda') and not torch.cuda.is_available():
|
84 |
logger.warning("CUDA is not available. GOT-OCR performs best with GPU acceleration.")
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
return True
|
87 |
except ImportError as e:
|
88 |
logger.error(f"Missing dependency: {e}")
|
89 |
return False
|
90 |
|
91 |
@classmethod
|
92 |
+
def _load_model(cls) -> bool:
|
93 |
+
"""Load the GOT-OCR model if it's not already loaded."""
|
94 |
+
if cls._model is not None and cls._processor is not None:
|
|
|
95 |
return True
|
96 |
|
97 |
try:
|
98 |
+
import torch
|
99 |
+
from transformers import AutoModelForImageTextToText, AutoProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
+
# Define the model name - using the HF model ID
|
102 |
+
model_name = "stepfun-ai/GOT-OCR-2.0-hf"
|
103 |
|
104 |
+
# Get the device (CUDA if available, otherwise CPU)
|
105 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
+
logger.info(f"Loading GOT-OCR model from {model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
+
# Load processor
|
110 |
+
cls._processor = AutoProcessor.from_pretrained(model_name)
|
|
|
|
|
111 |
|
112 |
+
# Load model
|
113 |
+
cls._model = AutoModelForImageTextToText.from_pretrained(
|
114 |
+
model_name,
|
115 |
+
low_cpu_mem_usage=True,
|
116 |
+
device_map=device
|
117 |
+
)
|
118 |
|
119 |
+
# Set model to evaluation mode
|
120 |
+
cls._model = cls._model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
logger.info("GOT-OCR model loaded successfully")
|
123 |
return True
|
124 |
|
125 |
except Exception as e:
|
126 |
+
logger.error(f"Failed to load GOT-OCR model: {str(e)}")
|
127 |
return False
|
128 |
|
129 |
def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
|
|
|
131 |
|
132 |
Args:
|
133 |
file_path: Path to the image file
|
134 |
+
ocr_method: OCR method to use ('plain', 'format', 'box', 'color', 'multi_crop')
|
135 |
**kwargs: Additional arguments to pass to the model
|
136 |
+
- box: For 'box' method, specify box coordinates [x1,y1,x2,y2]
|
137 |
+
- color: For 'color' method, specify color ('red', 'green', 'blue')
|
138 |
|
139 |
Returns:
|
140 |
Extracted text from the image, converted to Markdown if formatted
|
|
|
143 |
if not self._check_dependencies():
|
144 |
raise ImportError(
|
145 |
"Required dependencies are missing. Please install: "
|
146 |
+
"torch transformers"
|
|
|
147 |
)
|
148 |
|
|
|
|
|
|
|
|
|
149 |
# Validate file path and extension
|
150 |
file_path = Path(file_path)
|
151 |
if not file_path.exists():
|
|
|
157 |
f"Received file with extension: {file_path.suffix}"
|
158 |
)
|
159 |
|
160 |
+
# Determine OCR mode and parameters based on method
|
161 |
+
use_format = ocr_method == "format"
|
162 |
+
use_box = ocr_method == "box"
|
163 |
+
use_color = ocr_method == "color"
|
164 |
+
use_multi_crop = ocr_method == "multi_crop"
|
165 |
|
166 |
+
# Log the OCR method being used
|
167 |
+
logger.info(f"Using OCR method: {ocr_method or 'plain'}")
|
168 |
|
169 |
+
# Load the model if it's not already loaded
|
170 |
+
if not self._load_model():
|
171 |
+
raise RuntimeError("Failed to load GOT-OCR model")
|
172 |
+
|
173 |
+
# Process the image using transformers
|
174 |
try:
|
175 |
+
# Use the spaces.GPU decorator if available
|
176 |
+
if HAS_SPACES:
|
177 |
+
return self._process_image_with_gpu(
|
178 |
+
str(file_path),
|
179 |
+
use_format=use_format,
|
180 |
+
use_box=use_box,
|
181 |
+
use_color=use_color,
|
182 |
+
use_multi_crop=use_multi_crop,
|
183 |
+
**kwargs
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
return self._process_image(
|
187 |
+
str(file_path),
|
188 |
+
use_format=use_format,
|
189 |
+
use_box=use_box,
|
190 |
+
use_color=use_color,
|
191 |
+
use_multi_crop=use_multi_crop,
|
192 |
+
**kwargs
|
193 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
|
|
|
|
|
|
|
|
|
|
195 |
except Exception as e:
|
196 |
logger.error(f"Error processing image with GOT-OCR: {str(e)}")
|
197 |
|
|
|
202 |
"GPU out of memory while processing with GOT-OCR. "
|
203 |
"Try using a smaller image or a different parser."
|
204 |
)
|
205 |
+
elif "bfloat16" in str(e):
|
206 |
+
raise RuntimeError(
|
207 |
+
"CUDA device does not support bfloat16. This is a known issue with some GPUs. "
|
208 |
+
"Please try using a different parser or contact support."
|
209 |
+
)
|
210 |
|
211 |
# Generic error
|
212 |
raise RuntimeError(f"Error processing document with GOT-OCR: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
+
def _process_image(self, image_path: str, use_format: bool = False, use_box: bool = False, use_color: bool = False, use_multi_crop: bool = False, **kwargs) -> str:
    """Process an image with GOT-OCR model (no GPU decorator).

    Exactly one OCR mode is applied, chosen by the first flag that is set in
    this priority order: ``use_format``, ``use_box``, ``use_color``,
    ``use_multi_crop``; with no flag set the image is read as plain text.

    Args:
        image_path: Path of the image, handed to ``load_image``.
        use_format: Request formatted output; the decoded text is then passed
            through ``latex2markdown``.
        use_box: Restrict OCR to a box taken from ``kwargs['box']``
            (default ``'[100,100,200,200]'``). A string value is parsed as
            JSON after replacing single quotes with double quotes; if parsing
            fails a warning is logged and ``[100, 100, 200, 200]`` is used.
        use_color: Pass ``kwargs['color']`` (default ``'red'``) to the
            processor.
        use_multi_crop: Formatted output with ``crop_to_patches=True`` and
            ``max_patches=5``; the decoded text is also passed through
            ``latex2markdown``.
        **kwargs: Only ``box`` and ``color`` are read here.

    Returns:
        The decoded text with a trailing stop string removed and surrounding
        whitespace stripped.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        from transformers.image_utils import load_image
        import torch

        # Load the image
        image = load_image(image_path)

        # Define stop string
        stop_str = "<|im_end|>"

        # Step 1: work out the mode-specific processor arguments. The
        # generate/decode pipeline below is identical for every mode, so only
        # the arguments (and whether the output is LaTeX) differ.
        processor_kwargs = {}
        convert_latex = False

        if use_format:
            # Format mode (for LaTeX, etc.)
            processor_kwargs["format"] = True
            convert_latex = True

        elif use_box:
            # Box-based OCR
            box_coords = kwargs.get('box', '[100,100,200,200]')
            if isinstance(box_coords, str):
                # Convert string representation to list if needed
                import json
                try:
                    box_coords = json.loads(box_coords.replace("'", '"'))
                except json.JSONDecodeError:
                    logger.warning(f"Invalid box format: {box_coords}. Using default.")
                    box_coords = [100, 100, 200, 200]

            logger.info(f"Using box coordinates: {box_coords}")
            processor_kwargs["box"] = box_coords

        elif use_color:
            # Color-based OCR
            color = kwargs.get('color', 'red')
            logger.info(f"Using color filter: {color}")
            processor_kwargs["color"] = color

        elif use_multi_crop:
            # Multi-crop OCR (uses format mode, so output is converted too)
            logger.info("Using multi-crop OCR")
            processor_kwargs["format"] = True
            processor_kwargs["crop_to_patches"] = True
            processor_kwargs["max_patches"] = 5
            convert_latex = True

        # Step 2: shared pipeline - preprocess, generate, decode.
        inputs = self._processor([image], return_tensors="pt", **processor_kwargs)
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")

        # Generate text
        with torch.no_grad():
            generate_ids = self._model.generate(
                **inputs,
                do_sample=False,
                tokenizer=self._processor.tokenizer,
                stop_strings=stop_str,
                max_new_tokens=4096,
            )

        # Decode only the newly generated tokens (skip the prompt tokens)
        result = self._processor.decode(
            generate_ids[0, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # Convert to Markdown if the mode produced formatted output
        if convert_latex:
            l2m = latex2markdown.LaTeX2Markdown(result)
            result = l2m.to_markdown()

        # Clean up the result
        if result.endswith(stop_str):
            result = result[:-len(stop_str)]

        return result.strip()

    except Exception as e:
        logger.error(f"Error in _process_image: {str(e)}")
        raise
|
380 |
+
|
381 |
+
# Define the GPU-decorated function for ZeroGPU
|
382 |
if HAS_SPACES:
    @spaces.GPU(duration=180)  # Allocate up to 3 minutes for OCR processing
    def _process_image_with_gpu(self, image_path: str, use_format: bool = False, use_box: bool = False, use_color: bool = False, use_multi_crop: bool = False, **kwargs) -> str:
        """Run ``_process_image`` under the ``spaces.GPU`` decorator.

        All arguments are forwarded unchanged to ``_process_image`` and its
        return value is returned as is.
        """
        logger.info("Processing with ZeroGPU allocation")
        mode_flags = {
            "use_format": use_format,
            "use_box": use_box,
            "use_color": use_color,
            "use_multi_crop": use_multi_crop,
        }
        return self._process_image(image_path, **mode_flags, **kwargs)
else:
    # Undecorated stand-in with the same signature for when spaces is missing
    def _process_image_with_gpu(self, image_path: str, use_format: bool = False, use_box: bool = False, use_color: bool = False, use_multi_crop: bool = False, **kwargs) -> str:
        # Not expected to be reached when HAS_SPACES is False; it simply
        # delegates to the plain implementation.
        mode_flags = {
            "use_format": use_format,
            "use_box": use_box,
            "use_color": use_color,
            "use_multi_crop": use_multi_crop,
        }
        return self._process_image(image_path, **mode_flags, **kwargs)
|
407 |
|
408 |
@classmethod
def release_model(cls):
    """Drop the cached model and processor and free their memory.

    Does nothing when no model is currently held on the class.
    """
    if cls._model is None:
        return

    # Drop the class-level references so the objects become collectable
    cls._model = None
    cls._processor = None

    # Collect immediately instead of waiting for the next automatic pass
    import gc
    gc.collect()

    # Empty the CUDA cache when torch is importable and CUDA is available
    try:
        import torch
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.empty_cache()
            logger.info("CUDA cache cleared")
    except ImportError:
        pass

    logger.info("GOT-OCR model resources released")
|
430 |
|
431 |
# Try to register the parser
|
432 |
try:
|
433 |
# Only check basic imports, detailed dependency check happens in parse method
|
434 |
import torch
|
435 |
+
from transformers import AutoModelForImageTextToText, AutoProcessor
|
436 |
ParserRegistry.register(GotOcrParser)
|
437 |
logger.info("GOT-OCR parser registered successfully")
|
438 |
except ImportError as e:
|
src/utils/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
"""Utility functions for the Markit application."""
|
2 |
-
|
3 |
-
from src.utils.latex_converter import latex_to_markdown
|
4 |
-
|
5 |
-
__all__ = ['latex_to_markdown']
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/latex_converter.py
DELETED
@@ -1,186 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
import logging
|
3 |
-
|
4 |
-
# Configure logging
|
5 |
-
logger = logging.getLogger(__name__)
|
6 |
-
|
7 |
-
def latex_to_markdown(latex_text):
    """
    Turn LaTeX formatted GOT-OCR output into Markdown.

    The text is passed through the conversion stages in a fixed order:
    tables, math environments, formatting commands, lists, final cleanup.

    Args:
        latex_text (str): LaTeX formatted text

    Returns:
        str: Markdown formatted text ("" when the input is empty or None)
    """
    if not latex_text:
        return ""

    logger.info("Converting LaTeX to Markdown")

    # Each stage takes the text produced by the previous one
    stages = (
        convert_latex_tables,
        convert_math_environments,
        convert_formatting_commands,
        convert_latex_lists,
        cleanup_latex,
    )
    converted = latex_text
    for stage in stages:
        converted = stage(converted)

    logger.info("LaTeX to Markdown conversion completed")
    return converted
|
42 |
-
|
43 |
-
def convert_latex_tables(latex_text):
    """Convert LaTeX tables to Markdown tables.

    Every ``\\begin{tabular}...\\end{tabular}`` (or ``table``) span is replaced
    by a Markdown table: rows are split on ``\\\\``, cells on ``&``, and a
    ``---`` separator row is inserted after the first emitted row.

    For ``tabular`` environments the leading column specification (for
    example ``{|c|c|}``, optionally preceded by ``[pos]``) is dropped so it
    does not end up in the first cell. ``\\hline`` rules are removed, and
    chunks that are empty once the rules are removed produce no row.

    Args:
        latex_text (str): Text that may contain LaTeX tables.

    Returns:
        str: The text with each matched environment replaced.
    """
    # Detect tabular/table environments (non-greedy, spanning newlines)
    tabular_pattern = r'\\begin\{(tabular|table)\}(.*?)\\end\{(tabular|table)\}'
    # Optional [pos] argument, then the {column spec}; one level of nested
    # braces is accepted so specs such as {|p{3cm}|c|} are consumed whole.
    column_spec_pattern = r'^\s*(?:\[[^\]]*\]\s*)?\{(?:[^{}]|\{[^{}]*\})*\}'

    def replace_table(match):
        table_content = match.group(2)

        # The column spec is layout information, not cell content, and its
        # "|" characters would corrupt the Markdown row.
        if match.group(1) == 'tabular':
            table_content = re.sub(column_spec_pattern, '', table_content, count=1)

        md_rows = []
        for row in re.split(r'\\\\', table_content):
            # Drop horizontal rules before testing for emptiness so that a
            # chunk holding only \hline does not become a "|  |" row.
            row = row.replace('\\hline', '').strip()
            if not row:
                continue

            cells = [cell.strip() for cell in row.split('&')]
            md_rows.append('| ' + ' | '.join(cells) + ' |')

            # Header separator goes after the first row actually emitted,
            # sized to that row's column count.
            if len(md_rows) == 1:
                md_rows.append('| ' + ' | '.join(['---'] * len(cells)) + ' |')

        return '\n'.join(md_rows)

    # Replace all tabular environments
    return re.sub(tabular_pattern, replace_table, latex_text, flags=re.DOTALL)
|
87 |
-
|
88 |
-
def convert_math_environments(latex_text):
    """Convert LaTeX math environments to Markdown math syntax.

    ``equation``, ``align`` and ``eqnarray`` environments, including their
    starred forms such as ``equation*``, become ``$$ ... $$``. The inline
    delimiters ``\\(`` and ``\\)`` become ``$``, and ``math`` environments
    become ``$ ... $``.

    Args:
        latex_text (str): Text that may contain LaTeX math environments.

    Returns:
        str: The text with the environments rewritten.
    """
    result = latex_text

    # Display environments -> $$ ... $$. Processed one environment at a time,
    # in this order. The optional "*" is captured and must reappear in the
    # matching \end{...}, so starred and unstarred forms are never mixed.
    for env in ('equation', 'align', 'eqnarray'):
        pattern = r'\\begin\{%s(\*?)\}(.*?)\\end\{%s\1\}' % (env, env)
        result = re.sub(pattern, r'$$\2$$', result, flags=re.DOTALL)

    # Convert inline math \( ... \) delimiters to $
    result = re.sub(r'\\(\(|\))', '$', result)

    # Handle standalone math expressions
    result = re.sub(r'\\begin\{math\}(.*?)\\end\{math\}', r'$\1$', result, flags=re.DOTALL)

    return result
|
104 |
-
|
105 |
-
def convert_formatting_commands(latex_text):
|
106 |
-
"""Convert LaTeX formatting commands to Markdown syntax."""
|
107 |
-
result = latex_text
|
108 |
-
|
109 |
-
# Bold: \textbf{text} -> **text**
|
110 |
-
result = re.sub(r'\\textbf\{([^}]*)\}', r'**\1**', result)
|
111 |
-
result = re.sub(r'\\bf\{([^}]*)\}', r'**\1**', result)
|
112 |
-
|
113 |
-
# Italic: \textit{text} -> *text*
|
114 |
-
result = re.sub(r'\\textit\{([^}]*)\}', r'*\1*', result)
|
115 |
-
result = re.sub(r'\\it\{([^}]*)\}', r'*\1*', result)
|
116 |
-
result = re.sub(r'\\emph\{([^}]*)\}', r'*\1*', result)
|
117 |
-
|
118 |
-
# Underline: don't have direct equivalent in MD, use emphasis
|
119 |
-
result = re.sub(r'\\underline\{([^}]*)\}', r'_\1_', result)
|
120 |
-
|
121 |
-
# Section headings
|
122 |
-
result = re.sub(r'\\section\{([^}]*)\}', r'## \1', result)
|
123 |
-
result = re.sub(r'\\subsection\{([^}]*)\}', r'### \1', result)
|
124 |
-
result = re.sub(r'\\subsubsection\{([^}]*)\}', r'#### \1', result)
|
125 |
-
|
126 |
-
# Remove \title command
|
127 |
-
result = re.sub(r'\\title\{([^}]*)\}', r'# \1', result)
|
128 |
-
|
129 |
-
return result
|
130 |
-
|
131 |
-
def convert_latex_lists(latex_text):
    """Rewrite ``itemize`` and ``enumerate`` environments as Markdown lists.

    ``itemize`` entries become ``- item`` lines and ``enumerate`` entries
    become ``1. item`` lines numbered from 1. Each converted list is wrapped
    in a leading and a trailing newline. Itemize environments are converted
    before enumerate environments.
    """
    def _entries(found):
        # Text following each \item, up to the next \item or the end of the
        # environment body, with surrounding whitespace removed.
        raw_items = re.findall(r'\\item\s+(.*?)(?=\\item|$)', found.group(1), re.DOTALL)
        return [raw.strip() for raw in raw_items]

    def _as_bullets(found):
        lines = [f'- {entry}' for entry in _entries(found)]
        return '\n' + '\n'.join(lines) + '\n'

    def _as_numbered(found):
        lines = [f'{position}. {entry}' for position, entry in enumerate(_entries(found), start=1)]
        return '\n' + '\n'.join(lines) + '\n'

    # Unordered lists first, then ordered lists
    converted = re.sub(r'\\begin\{itemize\}(.*?)\\end\{itemize\}', _as_bullets, latex_text, flags=re.DOTALL)
    converted = re.sub(r'\\begin\{enumerate\}(.*?)\\end\{enumerate\}', _as_numbered, converted, flags=re.DOTALL)
    return converted
|
156 |
-
|
157 |
-
def cleanup_latex(latex_text):
    """Clean up any remaining LaTeX-specific syntax.

    Removes document-structure commands (``\\begin{document}``,
    ``\\end{document}``, ``\\maketitle``, ``\\documentclass`` and
    ``\\usepackage``, with or without a ``[...]`` optional argument),
    unescapes LaTeX special characters, replaces ``~`` with a space and
    ``\\ldots`` with ``...``, and collapses runs of three or more
    whitespace-separated newlines into one blank line.

    Args:
        latex_text (str): Text that may still contain LaTeX syntax.

    Returns:
        str: The cleaned text.
    """
    result = latex_text

    # Remove LaTeX document structure commands
    result = re.sub(r'\\begin\{document\}|\\end\{document\}', '', result)
    result = re.sub(r'\\maketitle', '', result)
    # The optional [options] group is consumed too; without it a command
    # such as \documentclass[12pt]{article} would be left in the output.
    result = re.sub(r'\\documentclass(?:\[[^\]]*\])?\{[^}]*\}', '', result)
    result = re.sub(r'\\usepackage(?:\[[^\]]*\])?\{[^}]*\}', '', result)

    # Convert special characters (applied in insertion order)
    latex_special_chars = {
        r'\&': '&',
        r'\%': '%',
        r'\$': '$',
        r'\#': '#',
        r'\_': '_',
        r'\{': '{',
        r'\}': '}',
        r'~': ' ',
        r'\ldots': '...'
    }

    for latex_char, md_char in latex_special_chars.items():
        result = result.replace(latex_char, md_char)

    # Fix extra whitespace
    result = re.sub(r'\n\s*\n\s*\n', '\n\n', result)

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|