Spanicin commited on
Commit
01480ea
·
verified ·
1 Parent(s): 4256658

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -0
app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pdf2image import convert_from_bytes, convert_from_path
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cv2
5
+ import pdfplumber
6
+ from transformers import AutoModel, AutoTokenizer
7
+ import io
8
+ import os
9
+ from PyPDF2 import PdfReader, PdfWriter
10
+ from langchain_openai import ChatOpenAI
11
+ from flask import Flask, request, jsonify, send_file
12
+ from flask_cors import CORS
13
+ import threading
14
+ import time
15
+ import uuid
16
+ import tempfile
17
+ import pytesseract
18
+ pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
19
+
20
+ app = Flask(__name__, template_folder='templates')
21
+ CORS(app, resources={r"/*": {"origins": ["http://localhost:*", "https://play.dev.ryzeai.ai", "https://ryze2ui.dev.ryzeai.ai"]}})
22
+
23
+ # Store process status and results
24
+ process_status = {}
25
+ process_results = {}
26
+ app.config['file_path'] = None
27
+
28
+ data_ready = False # Flag to check if extraction is complete
29
+ lock = threading.Lock() # Lock to manage concurrent access
30
+ extracted_texts = {}
31
+
32
+
33
+ ocr_tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
34
+ ocr_model = AutoModel.from_pretrained(
35
+ 'ucaslcl/GOT-OCR2_0', trust_remote_code=True,
36
+ low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True
37
+ ).eval().cuda()
38
+
39
+ API_KEY = "sk-8754"
40
+ BASE_URL = "https://aura.dev.ryzeai.ai"
41
+ llm = ChatOpenAI(temperature=0, model_name="azure/gpt-4o-mini", api_key=API_KEY, base_url=BASE_URL)
42
+
43
+ class DynamicTableExtractor:
44
+ def __init__(self, pdf_bytes: bytes, output_folder: str):
45
+ self.pdf_bytes = pdf_bytes
46
+ self.images = convert_from_bytes(pdf_bytes)
47
+ self.output_folder = output_folder
48
+ os.makedirs(output_folder, exist_ok=True)
49
+
50
+ def detect_lines(self, img_array):
51
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
52
+ thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
53
+
54
+ horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
55
+ horizontal_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel)
56
+
57
+ vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
58
+ vertical_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel)
59
+
60
+ return horizontal_lines, vertical_lines
61
+
62
+ def find_table_boundaries(self, horizontal_lines: np.ndarray, vertical_lines):
63
+ combined = cv2.addWeighted(horizontal_lines, 1, vertical_lines, 1, 0)
64
+ contours, _ = cv2.findContours(combined, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
65
+ table_boundaries = []
66
+ min_table_area = 5000
67
+
68
+ for contour in contours:
69
+ x, y, w, h = cv2.boundingRect(contour)
70
+ area = w * h
71
+ if area > min_table_area:
72
+ padding = 5
73
+ table_boundaries.append((
74
+ max(0, x - padding),
75
+ max(0, y - padding),
76
+ x + w + padding,
77
+ y + h + padding
78
+ ))
79
+ return table_boundaries
80
+
81
+ def detect_tables_by_text_alignment(self, img):
82
+ text_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
83
+ rows = {}
84
+
85
+ for i, text in enumerate(text_data['text']):
86
+ if text.strip():
87
+ y = text_data['top'][i]
88
+ if y not in rows:
89
+ rows[y] = []
90
+ rows[y].append({
91
+ 'text': text,
92
+ 'left': text_data['left'][i],
93
+ 'width': text_data['width'][i],
94
+ 'height': text_data['height'][i],
95
+ 'conf': text_data['conf'][i]
96
+ })
97
+
98
+ table_regions = []
99
+ current_region = None
100
+ last_y = None
101
+
102
+ for y, row_texts in sorted(rows.items()):
103
+ is_tabular = (
104
+ len(row_texts) >= 3 and
105
+ any(t['text'].replace('.', '').replace('-', '').replace('$', '').isdigit()
106
+ for t in row_texts)
107
+ )
108
+ if is_tabular:
109
+ if current_region and last_y and y - last_y > 50:
110
+ current_region['bottom'] = last_y + row_texts[0]['height']
111
+ table_regions.append(current_region)
112
+ current_region = None
113
+ if current_region is None:
114
+ current_region = {
115
+ 'top': y,
116
+ 'left': min(t['left'] for t in row_texts),
117
+ 'right': max(t['left'] + t['width'] for t in row_texts)
118
+ }
119
+ else:
120
+ current_region['right'] = max(
121
+ current_region['right'],
122
+ max(t['left'] + t['width'] for t in row_texts)
123
+ )
124
+ last_y = y
125
+ elif current_region is not None:
126
+ current_region['bottom'] = y
127
+ table_regions.append(current_region)
128
+ current_region = None
129
+
130
+ if current_region:
131
+ current_region['bottom'] = last_y + 20
132
+ table_regions.append(current_region)
133
+
134
+ return table_regions
135
+
136
+ def merge_boundaries(self, boundaries):
137
+ if not boundaries:
138
+ return []
139
+
140
+ def overlap_or_nearby(b1, b2, threshold=20):
141
+ return not (b1[2] + threshold < b2[0] or b2[2] + threshold < b1[0] or
142
+ b1[3] + threshold < b2[1] or b2[3] + threshold < b1[1])
143
+
144
+ merged = []
145
+ boundaries = sorted(boundaries, key=lambda x: (x[1], x[0]))
146
+ current = list(boundaries[0])
147
+
148
+ for next_bound in boundaries[1:]:
149
+ if overlap_or_nearby(current, next_bound):
150
+ current[0] = min(current[0], next_bound[0])
151
+ current[1] = min(current[1], next_bound[1])
152
+ current[2] = max(current[2], next_bound[2])
153
+ current[3] = max(current[3], next_bound[3])
154
+ else:
155
+ merged.append(tuple(current))
156
+ current = list(next_bound)
157
+
158
+ merged.append(tuple(current))
159
+ return merged
160
+
161
+ def remove_tables_from_image(self, img, table_boundaries):
162
+ img_array = np.array(img)
163
+
164
+ for x1, y1, x2, y2 in table_boundaries:
165
+ img_array[y1:y2, x1:x2] = 255 # Fill table area with white
166
+
167
+ return Image.fromarray(img_array)
168
+
169
+ def extract_tables(self) -> None:
170
+ for page_num, page_img in enumerate(self.images, start=1):
171
+ img_array = np.array(page_img)
172
+ horizontal_lines, vertical_lines = self.detect_lines(img_array)
173
+
174
+ line_based_boundaries = self.find_table_boundaries(horizontal_lines, vertical_lines)
175
+ text_based_regions = self.detect_tables_by_text_alignment(page_img)
176
+ text_based_boundaries = [
177
+ (r['left'], r['top'], r['right'], r['bottom'])
178
+ for r in text_based_regions
179
+ ]
180
+
181
+ all_boundaries = self.merge_boundaries(line_based_boundaries + text_based_boundaries)
182
+ cleaned_image = self.remove_tables_from_image(page_img, all_boundaries)
183
+ cleaned_output_path = os.path.join(self.output_folder, f'cleaned_page{page_num}.png')
184
+ cleaned_image.save(cleaned_output_path)
185
+
186
+ table_count = 0
187
+
188
+ for bounds in all_boundaries:
189
+ table_region = page_img.crop(bounds)
190
+ gray_table = cv2.cvtColor(np.array(table_region), cv2.COLOR_RGB2GRAY)
191
+ text = pytesseract.image_to_string(gray_table).strip()
192
+
193
+ if text:
194
+ table_count += 1
195
+ output_path = os.path.join(self.output_folder, f'page{page_num}_table{table_count}.png')
196
+ table_region.save(output_path)
197
+
198
+ def categorize_pdf_pages(pdf_path):
199
+ page_categories = {}
200
+ with pdfplumber.open(pdf_path) as pdf:
201
+ for page_number, page in enumerate(pdf.pages):
202
+ text = page.extract_text()
203
+ tables = page.extract_tables()
204
+ page_categories[page_number] = "text & table" if tables and text else "only table" if tables else "only text" if text else "empty"
205
+ return page_categories
206
+
207
+ def extract_text_from_image(image_path):
208
+ return ocr_model.chat(ocr_tokenizer, image_path, ocr_type='ocr')
209
+
210
+ def save_text_pages_as_images(pdf_path, categorized_pages, output_dir="output_images"):
211
+ os.makedirs(output_dir, exist_ok=True)
212
+ text_only_pages = [page_num for page_num, category in categorized_pages.items() if category == "only text"]
213
+ extracted_texts = {}
214
+ images = convert_from_path(pdf_path, dpi=300)
215
+ for page_num in text_only_pages:
216
+ image_path = f"{output_dir}/page_{page_num+1}.png"
217
+ images[page_num].save(image_path, 'PNG')
218
+ extracted_texts[page_num + 1] = extract_text_from_image(image_path)
219
+ return extracted_texts
220
+
221
+ def extract_text_from_table_pages(pdf_path, categorized_pages, output_folder="extracted_tables"):
222
+ extracted_texts = {}
223
+ table_pages = [page_num for page_num, category in categorized_pages.items() if category in ["only table", "text & table"]]
224
+ with open(pdf_path, "rb") as f:
225
+ pdf_reader = PdfReader(f)
226
+ for page_num in table_pages:
227
+ pdf_writer = PdfWriter()
228
+ pdf_writer.add_page(pdf_reader.pages[page_num])
229
+ pdf_bytes_io = io.BytesIO()
230
+ pdf_writer.write(pdf_bytes_io)
231
+ pdf_bytes = pdf_bytes_io.getvalue()
232
+ extractor = DynamicTableExtractor(pdf_bytes, output_folder)
233
+ extractor.extract_tables()
234
+ saved_images = sorted(os.listdir(output_folder))
235
+ page_images = [img for img in saved_images if img.endswith('.png')]
236
+ page_texts = [extract_text_from_image(os.path.join(output_folder, img)) for img in page_images]
237
+ if page_texts:
238
+ extracted_texts[page_num] = "\n".join(page_texts)
239
+ return extracted_texts
240
+
241
+
242
+ # @app.route('/upload', methods=['POST'])
243
+ # def extract_from_pdf():
244
+ # global extracted_texts
245
+ # if 'file' not in request.files:
246
+ # return jsonify({'error': 'No file provided'}), 400
247
+ # file = request.files['file']
248
+ # pdf_path = os.path.join("uploads", file.filename)
249
+ # os.makedirs("uploads", exist_ok=True)
250
+ # file.save(pdf_path)
251
+ # categorized_pages = categorize_pdf_pages(pdf_path)
252
+ # extracted_texts = save_text_pages_as_images(pdf_path, categorized_pages)
253
+ # table_texts = extract_text_from_table_pages(pdf_path, categorized_pages)
254
+ # extracted_texts.update(table_texts)
255
+ # return jsonify({'message': 'Extraction completed', 'data': extracted_texts})
256
+
257
+ # @app.route('/query', methods=['POST'])
258
+ # def query_extracted_data():
259
+ # global extracted_texts
260
+ # user_input = request.form['user_question']
261
+ # response = llm.invoke(str(extracted_texts) + " " + user_input)
262
+ # return jsonify({'response': response.content.strip()})
263
+
264
+ def save_extracted_text(text_dict, filepath):
265
+ with open(filepath, "w", encoding="utf-8") as f: # Open in text mode
266
+ for page, text in text_dict.items():
267
+ f.write(f"Page {page}:\n{text}\n\n")
268
+ return filepath
269
+
270
+ def process_pdf(pdf_path, process_id):
271
+ global extracted_texts, data_ready
272
+ with lock:
273
+ data_ready = False # Reset flag when new process starts
274
+
275
+ process_status[process_id] = "in_progress"
276
+ categorized_pages = categorize_pdf_pages(pdf_path)
277
+ extracted_texts = save_text_pages_as_images(pdf_path, categorized_pages)
278
+ table_texts = extract_text_from_table_pages(pdf_path, categorized_pages)
279
+ extracted_texts.update(table_texts)
280
+ temp_file_path = tempfile.mktemp(suffix='.txt')
281
+ filepath = save_extracted_text(extracted_texts, temp_file_path) # Save extracted text to file
282
+ app.config['file_path'] = filepath
283
+ process_status[process_id] = "completed"
284
+ process_results[process_id] = {
285
+ "response": extracted_texts,
286
+ }
287
+
288
+ with lock:
289
+ data_ready = True # Mark extraction as complete
290
+
291
+ @app.route('/upload', methods=['POST'])
292
+ def upload_pdf():
293
+
294
+ global extracted_texts, data_ready
295
+
296
+ if 'file' not in request.files:
297
+ return jsonify({'error': 'No file provided'}), 400
298
+
299
+ file = request.files['file']
300
+ pdf_path = os.path.join("uploads", file.filename)
301
+ os.makedirs("uploads", exist_ok=True)
302
+ file.save(pdf_path)
303
+ process_id = str(uuid.uuid4())
304
+ thread = threading.Thread(target=process_pdf, args=(pdf_path, process_id))
305
+ thread.start() # Start extraction in a separate thread
306
+
307
+ return jsonify({'message': 'File uploaded, extraction in progress', "process_id": process_id})
308
+
309
+ @app.route('/status', methods=['GET'])
310
+ def check_task_status():
311
+ process_id = request.args.get('process_id', None)
312
+ if process_id not in process_status:
313
+ return jsonify({"error": "Invalid process ID"}), 400
314
+
315
+ status = process_status[process_id]
316
+ if status == "completed":
317
+ result = process_results[process_id]
318
+ response = result["response"]
319
+
320
+ return jsonify({
321
+ "status": "completed",
322
+ "response": response,
323
+ "url": f"/download?file_path={app.config['file_path']}"
324
+ }), 200
325
+ elif status == "in_progress":
326
+ return jsonify({"status": "in_progress"}), 200
327
+ elif status == "error":
328
+ return jsonify({"status": "error", "error": process_results[process_id]["error"]}), 500
329
+
330
+ @app.route('/query', methods=['POST'])
331
+ def query_extracted_data():
332
+ process_id = request.args.get('process_id')
333
+ result = process_results[process_id]
334
+ text = result["response"]
335
+ user_input = request.form['user_question']
336
+ response = llm.invoke(str(text) + " " + user_input)
337
+
338
+ return jsonify({'response': response.content.strip()})
339
+
340
+ @app.route("/download")
341
+ def download_file():
342
+ file_path = app.config.get('file_path')
343
+ if file_path:
344
+ return send_file(file_path, as_attachment=True)
345
+ else:
346
+ return jsonify({"message": "File path is missing."}), 404
347
+
348
+
349
+
350
+
351
+ if __name__ == '__main__':
352
+ app.run(debug=False)
353
+ # Start Ngrok in a separate thread
354
+ # def start_ngrok():
355
+ # public_url = ngrok.connect(8000)
356
+ # print(f"Ngrok public URL: {public_url}")
357
+
358
+ # ngrok_thread = Thread(target=start_ngrok)
359
+ # ngrok_thread.start()
360
+
361
+ # # Run Flask app
362
+ # app.run(port=8000)