feras-vbrl committed on
Commit
6a7a825
·
verified ·
1 Parent(s): b4e5504

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +393 -0
app.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tempfile
3
+ import os
4
+ import time
5
+ import logging
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from urllib.parse import urlparse
9
+ import requests
10
+ from PIL import Image
11
+ import fitz # PyMuPDF for PDF processing
12
+ from vllm import LLM, SamplingParams
13
+ from docling_core.types.doc import DoclingDocument
14
+ from docling_core.types.doc.document import DocTagsDocument
15
+
# Logging setup: INFO-level records on the root logger, each prefixed with a
# timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.info("SmolDocling OCR App starting up...")

# Cache directory: taken from the CACHE_DIR environment variable when set,
# otherwise a fixed location under /tmp. It is created at import time.
CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/smoldocling_cache")
os.makedirs(CACHE_DIR, exist_ok=True)
logger.info(f"Cache directory set to: {CACHE_DIR}")
25
+
# Custom DocumentConverter class that uses vLLM for fast inference
class VLLMDocumentConverter:
    """Convert images and PDFs into a Docling document using a vLLM-served model.

    A source is either a local file path or an http(s) URL. PDFs are rendered
    page by page with PyMuPDF; every page image is then sent through the model
    and the generated DocTags are assembled into one ``DoclingDocument``.
    """

    def __init__(self, model_name="ds4sd/SmolDocling-256M-preview"):
        """
        Initialize the converter with vLLM for fast inference

        Args:
            model_name: The name of the model to use

        Raises:
            Exception: Whatever vLLM raises while loading the model; the error
                is logged and re-raised unchanged.
        """
        logger.info("Loading SmolDocling model with vLLM...")
        try:
            # Initialize vLLM; every prompt carries exactly one page image.
            self.model_path = model_name
            self.llm = LLM(model=self.model_path, limit_mm_per_prompt={"image": 1})
            self.sampling_params = SamplingParams(
                temperature=0.0,
                max_tokens=8192
            )
            logger.info("Model loaded successfully with vLLM")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    @staticmethod
    def _is_remote_source(source):
        """Return True when *source* is an http(s) URL rather than a local path.

        Checking for the two supported schemes (instead of "any scheme") keeps
        Windows paths such as ``C:\\docs\\a.pdf`` from being mistaken for URLs,
        because ``urlparse`` reports the drive letter as a scheme.
        """
        return urlparse(str(source)).scheme in ("http", "https")

    @staticmethod
    def _is_pdf_response(url, content_type):
        """Return True when a downloaded resource should be handled as a PDF.

        Args:
            url: The requested URL; only its path component is inspected, so a
                query string after ``.pdf`` does not hide the extension.
            content_type: Raw ``Content-Type`` header value; parameters such as
                ``; charset=...`` are ignored.
        """
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type == "application/pdf":
            return True
        return urlparse(url).path.lower().endswith(".pdf")

    def load_image_from_path(self, file_path, max_pages=None):
        """Load page images from a local file, handling both images and PDFs.

        Args:
            file_path: Path (``str`` or ``os.PathLike``) of a PDF or image file.
            max_pages: Optional cap on the number of PDF pages to render.

        Returns:
            A list of RGB PIL images (a single-element list for an image file).

        Raises:
            Exception: Any error from opening or decoding the file is logged
                and re-raised.
        """
        logger.debug(f"Loading from path: {file_path}")
        try:
            path_text = os.fspath(file_path)
            # Check if it's a PDF
            if path_text.lower().endswith('.pdf'):
                return self.convert_pdf_to_images(path_text, max_pages=max_pages)

            # It's an image. convert() yields an independent in-memory image,
            # so the underlying file handle can be closed right away.
            with Image.open(path_text) as source_image:
                pil_image = source_image.convert("RGB")
            logger.debug(f"Image loaded successfully: {pil_image.size}")
            return [pil_image]  # Return as a list for consistency
        except Exception as e:
            logger.error(f"Error loading file: {str(e)}")
            raise

    def convert_pdf_to_images(self, pdf_path, max_pages=None):
        """Convert a PDF to a list of RGB PIL images, one per page.

        Args:
            pdf_path: Path of the PDF file to open with PyMuPDF.
            max_pages: Optional cap on the number of leading pages to render;
                ``None`` (or 0) renders every page.

        Returns:
            The rendered pages in document order.

        Raises:
            Exception: Any PyMuPDF or PIL error is logged and re-raised.
        """
        logger.debug(f"Converting PDF to images: {pdf_path}")
        try:
            images = []
            with fitz.open(pdf_path) as doc:
                logger.debug(f"PDF has {len(doc)} pages")
                for page_num, page in enumerate(doc):
                    # Stop before rendering pages the caller would discard.
                    if max_pages and page_num >= max_pages:
                        logger.debug(f"Limiting processing to {max_pages} pages out of {len(doc)}")
                        break
                    logger.debug(f"Processing page {page_num+1}")
                    # Render page to an image with higher resolution
                    pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
                    img_data = pix.tobytes("png")
                    with Image.open(BytesIO(img_data)) as rendered:
                        images.append(rendered.convert("RGB"))

            logger.debug(f"Converted {len(images)} pages to images")
            return images
        except Exception as e:
            logger.error(f"Error converting PDF to images: {str(e)}")
            raise

    def load_image_from_url(self, url, max_pages=None):
        """Download a URL and load it as page images, handling images and PDFs.

        Args:
            url: http(s) URL of a PDF or image.
            max_pages: Optional cap on the number of PDF pages to render.

        Returns:
            A list of RGB PIL images (a single-element list for an image).

        Raises:
            Exception: Network, HTTP-status and decoding errors are logged and
                re-raised.
        """
        logger.debug(f"Loading from URL: {url}")
        try:
            # The context manager releases the connection even when
            # raise_for_status() fails before the body has been read.
            with requests.get(url, stream=True, timeout=10) as response:
                response.raise_for_status()
                content_type = response.headers.get('Content-Type', '')
                payload = response.content

            # Check if it's a PDF
            if self._is_pdf_response(url, content_type):
                # Save PDF to a temporary file
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                    tmp_file.write(payload)
                    tmp_path = tmp_file.name

                try:
                    # Convert PDF to images
                    return self.convert_pdf_to_images(tmp_path, max_pages=max_pages)
                finally:
                    # Clean up temporary file
                    if os.path.exists(tmp_path):
                        os.unlink(tmp_path)

            # It's an image
            with Image.open(BytesIO(payload)) as source_image:
                pil_image = source_image.convert("RGB")
            logger.debug(f"Image loaded successfully: {pil_image.size}")
            return [pil_image]  # Return as a list for consistency
        except Exception as e:
            logger.error(f"Error loading from URL: {str(e)}")
            raise

    def process_images(self, images, prompt="Convert page to Docling."):
        """Run the model on each image and return the generated DocTags.

        Args:
            images: Page images to process.
            prompt: Instruction inserted into the chat template for every page.

        Returns:
            A list with one generated text per input image, in input order.
        """
        logger.debug(f"Processing {len(images)} images with prompt: {prompt}")

        start_time = time.time()
        all_outputs = []

        # Create chat template (identical for every page, so built once)
        chat_template = f"<|im_start|>User:<image>{prompt}<end_of_utterance>\nAssistant:"

        # Process each image
        for i, image in enumerate(images):
            logger.debug(f"Processing image {i+1} of {len(images)}")

            # Prepare input for vLLM
            llm_input = {"prompt": chat_template, "multi_modal_data": {"image": image}}

            # Generate output; one request in, so take the first result
            output = self.llm.generate([llm_input], sampling_params=self.sampling_params)[0]
            doctags = output.outputs[0].text

            all_outputs.append(doctags)
            logger.debug(f"Generated doctags for image {i+1} (length: {len(doctags)})")

        logger.debug(f"Total processing time: {time.time() - start_time:.2f} seconds")
        return all_outputs

    def convert_to_markdown(self, images, prompt="Convert page to Docling."):
        """Build a Docling document from page images.

        Despite the name, the return value is the ``DoclingDocument`` itself;
        callers obtain markdown via its ``export_to_markdown()`` method.

        Args:
            images: Page images to convert.
            prompt: Instruction passed to ``process_images``.

        Returns:
            The populated ``DoclingDocument``.

        Raises:
            Exception: Any inference or document-building error is logged and
                re-raised.
        """
        logger.debug(f"Converting {len(images)} images to markdown with prompt: {prompt}")
        try:
            # Process images
            all_outputs = self.process_images(images, prompt)

            # Populate document with all pages
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(all_outputs, images)
            # Create a docling document
            doc = DoclingDocument(name="ConvertedDocument")
            doc.load_from_doctags(doctags_doc)

            # The export here exists only to report a length, so skip the work
            # unless debug logging is actually enabled.
            if logger.isEnabledFor(logging.DEBUG):
                markdown_text = doc.export_to_markdown()
                logger.debug(f"Combined markdown text length: {len(markdown_text)}")

            return doc
        except Exception as e:
            logger.error(f"Error converting to markdown: {str(e)}")
            raise

    def convert(self, source, prompt="Convert page to Docling.", max_pages=None):
        """
        Convert a PDF/image to markdown

        Args:
            source: Either a path to a file or an http(s) URL
            prompt: The prompt to use for conversion
            max_pages: Maximum number of pages to process

        Returns:
            An object with a document attribute that has an export_to_markdown method

        Raises:
            Exception: Any loading or conversion error is logged and re-raised.
        """
        logger.debug(f"Converting source: {source}")
        try:
            # Only http(s) sources are downloaded; everything else is a path.
            if self._is_remote_source(source):
                images = self.load_image_from_url(source, max_pages=max_pages)
            else:
                images = self.load_image_from_path(source, max_pages=max_pages)

            # The loaders already honour max_pages; this trim is a cheap
            # safeguard in case a loader returned more pages than requested.
            if max_pages and max_pages < len(images):
                logger.debug(f"Limiting processing to {max_pages} pages out of {len(images)}")
                images = images[:max_pages]

            # Convert to markdown
            doc = self.convert_to_markdown(images, prompt)

            # Return the document
            return ConversionResult(doc)
        except Exception as e:
            logger.error(f"Error in convert method: {str(e)}")
            raise
201
+
class ConversionResult:
    """Minimal result wrapper mirroring the original DocumentConverter result.

    The converted document is exposed through the ``document`` attribute.
    """

    def __init__(self, document):
        # Kept by reference; nothing is copied or validated here.
        self.document = document
206
+
# Custom CSS for better layout: restyles the file uploader, its primary
# button, regular buttons (including hover) and the dropzone hover state.
_CUSTOM_CSS = """
<style>
    .stFileUploader {
        padding: 1rem;
    }

    button[data-testid="stFileUploaderButtonPrimary"] {
        background-color: #000660 !important;
        border: none !important;
        color: white !important;
    }

    .stButton button {
        background-color: #006666;
        border: none !important;
        color: white;
        padding: 0.5rem 2rem !important;
    }
    .stButton button:hover {
        background-color: #008080 !important;
        color: white !important;
        border-color: #008080 !important;
    }
    .upload-text {
        font-size: 1.2rem;
        margin-bottom: 1rem;
    }
    div[data-testid="stFileUploadDropzone"]:hover {
        border-color: #006666 !important;
        background-color: rgba(0, 102, 102, 0.05) !important;
    }
</style>
"""
# unsafe_allow_html is passed so the <style> element is emitted as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
241
+
def _markdown_name_from_url(url):
    """Derive the ``.md`` download filename from the last path segment of *url*.

    Only the URL path is inspected, so query strings and fragments never end
    up in the filename. Just the final extension is stripped (matching how the
    upload branch uses ``os.path.splitext``), and a URL without a usable last
    segment falls back to ``converted.md``.
    """
    last_segment = urlparse(url).path.rstrip("/").rsplit("/", 1)[-1]
    stem = os.path.splitext(last_segment)[0]
    return (stem or "converted") + ".md"


def _show_conversion_result(markdown_text, output_filename):
    """Render the success banner, download button and preview for a conversion."""
    st.success("Conversion completed!")
    st.download_button(
        label="Download Markdown file",
        data=markdown_text,
        file_name=output_filename,
        mime="text/markdown"
    )

    # Display the markdown
    st.subheader("Preview:")
    st.markdown(markdown_text)


def main():
    """Run the Streamlit UI: load the model, collect a source, convert it.

    The sidebar holds the model name, the page limit and a reload button. The
    converter is kept in ``st.session_state`` so it is only built on the first
    run or when "Reload Model" is pressed. On "Convert to Markdown" the
    uploaded file takes precedence over the URL field; errors are logged and
    shown in the page instead of propagating.
    """
    logger.info("Starting SmolDocling OCR App main function")

    st.title("PDF to Markdown Converter")
    st.subheader("Using SmolDocling OCR with vLLM")

    # Add a sidebar for model and processing settings
    st.sidebar.title("Settings")

    # Model settings
    st.sidebar.subheader("Model Settings")
    model_name = st.sidebar.text_input(
        "Model Name",
        value="ds4sd/SmolDocling-256M-preview",
        help="Enter the name of the model to use for PDF to Markdown conversion"
    )

    # Processing settings
    st.sidebar.subheader("Processing Settings")
    max_pages = st.sidebar.slider(
        "Maximum Pages to Process",
        1, 50, 10,
        help="Limit the number of pages to process for large PDFs"
    )

    st.sidebar.markdown("""
    ### About This App
    This app uses the SmolDocling model with vLLM for fast inference to convert PDFs and images to Markdown.

    vLLM is a high-performance library for LLM inference that can significantly speed up processing.
    """)

    # Create a button to reload the model if the model name changes
    reload_model = st.sidebar.button("Reload Model")

    # Initialize or reload the converter when needed
    if 'converter' not in st.session_state or reload_model:
        try:
            with st.spinner(f"Loading model {model_name}... This may take a while for the first time."):
                logger.debug(f"Creating VLLMDocumentConverter instance with model: {model_name}")
                st.session_state.converter = VLLMDocumentConverter(model_name=model_name)
                logger.debug("Converter successfully created")
            st.sidebar.success(f"Model {model_name} loaded successfully!")
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Error creating converter: {error_msg}")
            st.error(f"Error creating converter: {error_msg}")

    # Without a converter nothing below can work, so end this script run.
    if 'converter' not in st.session_state:
        st.stop()

    # Main upload area
    uploaded_file = st.file_uploader(
        "Upload your PDF or image file",
        type=['pdf', 'png', 'jpg', 'jpeg'],
        key='file_uploader',
        help="Drag and drop or click to select a file (max 200MB)"
    )

    # URL input area with spacing
    st.markdown("<br>", unsafe_allow_html=True)
    url = st.text_input("Or enter a PDF/image URL").strip()

    # Prompt input; the field is labelled optional, so a blank value falls
    # back to the default prompt instead of sending an empty instruction.
    default_prompt = "Convert page to Docling."
    prompt = st.text_input("Conversion prompt (optional)", value=default_prompt)
    prompt = prompt.strip() or default_prompt

    # Unified convert button
    convert_clicked = st.button("Convert to Markdown", type="primary")

    # Process either uploaded file or URL
    if convert_clicked:
        if uploaded_file is not None:
            try:
                with st.spinner('Converting file...'):
                    # splitext yields "" for a name without a dot, rather than
                    # turning the whole name into the suffix.
                    suffix = os.path.splitext(uploaded_file.name)[1]
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                        tmp_file.write(uploaded_file.getvalue())
                        tmp_path = tmp_file.name
                        logger.debug(f"Temporary file created at: {tmp_path}")

                    try:
                        logger.debug(f"Converting file: {uploaded_file.name}")
                        # Convert the file
                        result = st.session_state.converter.convert(
                            tmp_path,
                            prompt=prompt,
                            max_pages=max_pages
                        )
                        markdown_text = result.document.export_to_markdown()
                        logger.debug(f"Markdown text length: {len(markdown_text)}")

                        output_filename = os.path.splitext(uploaded_file.name)[0] + '.md'
                        _show_conversion_result(markdown_text, output_filename)

                    except Exception as e:
                        logger.error(f"Error converting file: {str(e)}")
                        st.error(f"Error converting file: {str(e)}")

                    finally:
                        # The temp file was created with delete=False, so it
                        # must be removed explicitly on every path.
                        if os.path.exists(tmp_path):
                            os.unlink(tmp_path)
                            logger.debug("Temporary file deleted")

            except Exception as e:
                logger.error(f"Error processing file: {str(e)}")
                st.error(f"Error processing file: {str(e)}")

        elif url:
            try:
                with st.spinner('Converting from URL...'):
                    logger.debug(f"Converting from URL: {url}")
                    # Convert from URL
                    result = st.session_state.converter.convert(
                        url,
                        prompt=prompt,
                        max_pages=max_pages
                    )
                    markdown_text = result.document.export_to_markdown()
                    logger.debug(f"Markdown text length: {len(markdown_text)}")

                    _show_conversion_result(markdown_text, _markdown_name_from_url(url))

            except Exception as e:
                logger.error(f"Error converting from URL: {str(e)}")
                st.error(f"Error converting from URL: {str(e)}")
        else:
            st.warning("Please upload a file or enter a URL first")


if __name__ == "__main__":
    main()