# pdfextraction / app.py
# (Hugging Face Space page residue kept as comments: uploader "Spanicin",
#  commit message "Update app.py", commit "bf2c469 verified")
import os
from pdf2image import convert_from_bytes, convert_from_path
from PIL import Image
import numpy as np
import cv2
import pdfplumber
from transformers import AutoModel, AutoTokenizer
import io
import os
from PyPDF2 import PdfReader, PdfWriter
from langchain_openai import ChatOpenAI
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import threading
import time
import uuid
import tempfile
import pytesseract
# --- One-time module setup: OCR binary, Flask app, shared state, models ---
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

app = Flask(__name__, template_folder='templates')
CORS(app, resources={r"/*": {"origins": ["http://localhost:*", "https://play.dev.ryzeai.ai", "https://ryze2ui.dev.ryzeai.ai"]}})

# Background-job bookkeeping, keyed by the uuid4 process_id handed to clients.
process_status = {}
process_results = {}
app.config['file_path'] = None

# Per-run scratch directory for uploads, page images and extracted text.
TEMP_DIR = tempfile.mkdtemp()
data_ready = False  # Flag to check if extraction is complete
lock = threading.Lock()  # Lock to manage concurrent access
extracted_texts = {}

# NOTE(review): transformers was imported above, so it may have already read
# HF_HOME when computing its cache paths; setting it here might not redirect
# the model cache — confirm, or set it before the transformers import.
os.environ["HF_HOME"] = os.path.join(TEMP_DIR, "cache")  # "/app/cache"

# GOT-OCR2 model used for all image-to-text extraction (requires CUDA).
ocr_tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
ocr_model = AutoModel.from_pretrained(
    'ucaslcl/GOT-OCR2_0', trust_remote_code=True,
    low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True
).eval().cuda()

# Read LLM credentials from the environment when available; fall back to the
# previous hard-coded values so existing deployments keep working.  Avoid
# committing real secrets in source.
API_KEY = os.environ.get("RYZE_API_KEY", "sk-8754")
BASE_URL = os.environ.get("RYZE_BASE_URL", "https://aura.dev.ryzeai.ai")
llm = ChatOpenAI(temperature=0, model_name="azure/gpt-4o-mini", api_key=API_KEY, base_url=BASE_URL)
class DynamicTableExtractor:
    """Detect, crop and erase table regions on each page of an in-memory PDF.

    Tables are located two ways — by ruled lines (morphological opening) and
    by the alignment of OCR'd numeric text — then the merged regions are
    cropped to PNG files and blanked out of a "cleaned" copy of each page.
    """

    def __init__(self, pdf_bytes: bytes, output_folder: str):
        self.pdf_bytes = pdf_bytes
        # One PIL image per page, rendered from the in-memory PDF.
        self.images = convert_from_bytes(pdf_bytes)
        # join() is a no-op when output_folder is already an absolute path.
        self.output_folder = os.path.join(TEMP_DIR, output_folder)
        os.makedirs(self.output_folder, exist_ok=True)

    def detect_lines(self, img_array):
        """Return (horizontal, vertical) binary line masks for one page."""
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
        # Opening with long thin kernels keeps only line-like runs of pixels.
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
        horizontal_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel)
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
        vertical_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel)
        return horizontal_lines, vertical_lines

    def find_table_boundaries(self, horizontal_lines: np.ndarray, vertical_lines):
        """Turn line masks into padded (x1, y1, x2, y2) candidate boxes."""
        combined = cv2.addWeighted(horizontal_lines, 1, vertical_lines, 1, 0)
        contours, _ = cv2.findContours(combined, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        table_boundaries = []
        min_table_area = 5000  # ignore tiny boxes that are unlikely to be tables
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w * h > min_table_area:
                padding = 5
                table_boundaries.append((
                    max(0, x - padding),
                    max(0, y - padding),
                    x + w + padding,
                    y + h + padding
                ))
        return table_boundaries

    def detect_tables_by_text_alignment(self, img):
        """Find borderless tables as runs of rows containing aligned numbers.

        Returns a list of dicts with 'top'/'left'/'right'/'bottom' keys.
        """
        text_data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
        # Group recognized words by their exact top coordinate (one visual row).
        rows = {}
        for i, text in enumerate(text_data['text']):
            if text.strip():
                y = text_data['top'][i]
                if y not in rows:
                    rows[y] = []
                rows[y].append({
                    'text': text,
                    'left': text_data['left'][i],
                    'width': text_data['width'][i],
                    'height': text_data['height'][i],
                    'conf': text_data['conf'][i]
                })
        table_regions = []
        current_region = None
        last_y = None
        last_height = 0  # height of the most recent tabular row, for closing regions
        for y, row_texts in sorted(rows.items()):
            # A row "looks tabular" if it has >= 3 words and at least one
            # numeric-looking token (digits after stripping '.', '-', '$').
            is_tabular = (
                len(row_texts) >= 3 and
                any(t['text'].replace('.', '').replace('-', '').replace('$', '').isdigit()
                    for t in row_texts)
            )
            if is_tabular:
                # A large vertical gap (> 50 px) ends the current region.
                if current_region and last_y and y - last_y > 50:
                    # BUG FIX: close the region using the *previous* row's
                    # height, not the first word of the new, far-away row.
                    current_region['bottom'] = last_y + last_height
                    table_regions.append(current_region)
                    current_region = None
                if current_region is None:
                    current_region = {
                        'top': y,
                        'left': min(t['left'] for t in row_texts),
                        'right': max(t['left'] + t['width'] for t in row_texts)
                    }
                else:
                    current_region['right'] = max(
                        current_region['right'],
                        max(t['left'] + t['width'] for t in row_texts)
                    )
                last_y = y
                last_height = max(t['height'] for t in row_texts)
            elif current_region is not None:
                # First non-tabular row closes the open region at its top.
                current_region['bottom'] = y
                table_regions.append(current_region)
                current_region = None
        if current_region:
            # Region ran to the end of the page; close with a small margin.
            current_region['bottom'] = last_y + 20
            table_regions.append(current_region)
        return table_regions

    def merge_boundaries(self, boundaries):
        """Merge overlapping or nearby (within 20 px) boxes into single boxes."""
        if not boundaries:
            return []

        def overlap_or_nearby(b1, b2, threshold=20):
            return not (b1[2] + threshold < b2[0] or b2[2] + threshold < b1[0] or
                        b1[3] + threshold < b2[1] or b2[3] + threshold < b1[1])

        merged = []
        boundaries = sorted(boundaries, key=lambda x: (x[1], x[0]))
        current = list(boundaries[0])
        for next_bound in boundaries[1:]:
            if overlap_or_nearby(current, next_bound):
                # Expand the current box to cover the neighbor.
                current[0] = min(current[0], next_bound[0])
                current[1] = min(current[1], next_bound[1])
                current[2] = max(current[2], next_bound[2])
                current[3] = max(current[3], next_bound[3])
            else:
                merged.append(tuple(current))
                current = list(next_bound)
        merged.append(tuple(current))
        return merged

    def remove_tables_from_image(self, img, table_boundaries):
        """Return a copy of img with every table box painted white."""
        img_array = np.array(img)
        for x1, y1, x2, y2 in table_boundaries:
            img_array[y1:y2, x1:x2] = 255  # Fill table area with white
        return Image.fromarray(img_array)

    def extract_tables(self) -> None:
        """Process every page: save cropped table PNGs and a cleaned page PNG."""
        for page_num, page_img in enumerate(self.images, start=1):
            img_array = np.array(page_img)
            horizontal_lines, vertical_lines = self.detect_lines(img_array)
            line_based_boundaries = self.find_table_boundaries(horizontal_lines, vertical_lines)
            text_based_regions = self.detect_tables_by_text_alignment(page_img)
            text_based_boundaries = [
                (r['left'], r['top'], r['right'], r['bottom'])
                for r in text_based_regions
            ]
            all_boundaries = self.merge_boundaries(line_based_boundaries + text_based_boundaries)
            # BUG FIX: clamp boxes to the page so padded boundaries cannot run
            # past the image edge (PIL pads out-of-range crops with black).
            width, height = page_img.size
            all_boundaries = [
                (max(0, int(x1)), max(0, int(y1)), min(width, int(x2)), min(height, int(y2)))
                for x1, y1, x2, y2 in all_boundaries
            ]
            cleaned_image = self.remove_tables_from_image(page_img, all_boundaries)
            cleaned_output_path = os.path.join(self.output_folder, f'cleaned_page{page_num}.png')
            cleaned_image.save(cleaned_output_path)
            table_count = 0
            for bounds in all_boundaries:
                table_region = page_img.crop(bounds)
                gray_table = cv2.cvtColor(np.array(table_region), cv2.COLOR_RGB2GRAY)
                # Keep the crop only if OCR actually finds text inside it.
                text = pytesseract.image_to_string(gray_table).strip()
                if text:
                    table_count += 1
                    output_path = os.path.join(self.output_folder, f'page{page_num}_table{table_count}.png')
                    table_region.save(output_path)
def categorize_pdf_pages(pdf_path):
    """Classify every page of the PDF by its text/table content.

    Returns a dict mapping 0-based page number to one of:
    "text & table", "only table", "only text", or "empty".
    """
    categories = {}
    with pdfplumber.open(pdf_path) as pdf:
        for index, page in enumerate(pdf.pages):
            has_text = bool(page.extract_text())
            has_tables = bool(page.extract_tables())
            if has_tables and has_text:
                label = "text & table"
            elif has_tables:
                label = "only table"
            elif has_text:
                label = "only text"
            else:
                label = "empty"
            categories[index] = label
    return categories
def extract_text_from_image(image_path):
    """Run the module-level GOT-OCR2 model on the image file and return its text."""
    return ocr_model.chat(ocr_tokenizer, image_path, ocr_type='ocr')
def save_text_pages_as_images(pdf_path, categorized_pages, output_dir="output_images"):
    """Render each "only text" page to a 300-dpi PNG and OCR it.

    Args:
        pdf_path: path to the source PDF.
        categorized_pages: mapping of 0-based page number -> category string,
            as produced by categorize_pdf_pages().
        output_dir: folder name (created under TEMP_DIR) for the page images.

    Returns:
        dict mapping 1-based page number -> OCR'd text.
    """
    output_dir = os.path.join(TEMP_DIR, output_dir)
    os.makedirs(output_dir, exist_ok=True)
    text_only_pages = [page_num for page_num, category in categorized_pages.items() if category == "only text"]
    page_texts = {}
    # Rasterize the whole document once; index into the list per page below.
    images = convert_from_path(pdf_path, dpi=300)
    for page_num in text_only_pages:
        # Use os.path.join instead of manual "/" concatenation for portability.
        image_path = os.path.join(output_dir, f"page_{page_num + 1}.png")
        images[page_num].save(image_path, 'PNG')
        page_texts[page_num + 1] = extract_text_from_image(image_path)
    return page_texts
def extract_text_from_table_pages(pdf_path, categorized_pages, output_folder="extracted_tables"):
    """Extract and OCR the table regions of every table-bearing page.

    Each qualifying page is written to a one-page PDF in memory, run through
    DynamicTableExtractor, and the cropped table images are OCR'd.

    Args:
        pdf_path: path to the source PDF.
        categorized_pages: mapping of 0-based page number -> category string.
        output_folder: folder name (created under TEMP_DIR) for table crops.

    Returns:
        dict mapping 0-based page number -> newline-joined table text.
    """
    output_folder = os.path.join(TEMP_DIR, output_folder)
    os.makedirs(output_folder, exist_ok=True)
    extracted = {}
    table_pages = [page_num for page_num, category in categorized_pages.items() if category in ["only table", "text & table"]]
    with open(pdf_path, "rb") as f:
        pdf_reader = PdfReader(f)
        for page_num in table_pages:
            # Build a single-page PDF in memory for this page only.
            pdf_writer = PdfWriter()
            pdf_writer.add_page(pdf_reader.pages[page_num])
            pdf_bytes_io = io.BytesIO()
            pdf_writer.write(pdf_bytes_io)
            pdf_bytes = pdf_bytes_io.getvalue()
            # BUG FIX: previously every page shared one folder, so each
            # extractor overwrote the previous page's page1_tableK.png files
            # and the directory listing mixed results from all pages —
            # including the cleaned_page*.png images, which are pages with
            # the tables REMOVED.  Give each page its own subfolder and OCR
            # only the cropped table images.
            page_folder = os.path.join(output_folder, f"page_{page_num + 1}")
            extractor = DynamicTableExtractor(pdf_bytes, page_folder)
            extractor.extract_tables()
            table_images = sorted(
                img for img in os.listdir(extractor.output_folder)
                if img.endswith('.png') and not img.startswith('cleaned_')
            )
            page_texts = [extract_text_from_image(os.path.join(extractor.output_folder, img)) for img in table_images]
            if page_texts:
                extracted[page_num] = "\n".join(page_texts)
    return extracted
# @app.route('/upload', methods=['POST'])
# def extract_from_pdf():
# global extracted_texts
# if 'file' not in request.files:
# return jsonify({'error': 'No file provided'}), 400
# file = request.files['file']
# pdf_path = os.path.join("uploads", file.filename)
# os.makedirs("uploads", exist_ok=True)
# file.save(pdf_path)
# categorized_pages = categorize_pdf_pages(pdf_path)
# extracted_texts = save_text_pages_as_images(pdf_path, categorized_pages)
# table_texts = extract_text_from_table_pages(pdf_path, categorized_pages)
# extracted_texts.update(table_texts)
# return jsonify({'message': 'Extraction completed', 'data': extracted_texts})
# @app.route('/query', methods=['POST'])
# def query_extracted_data():
# global extracted_texts
# user_input = request.form['user_question']
# response = llm.invoke(str(extracted_texts) + " " + user_input)
# return jsonify({'response': response.content.strip()})
def save_extracted_text(text_dict, filepath):
    """Write each page's text to filepath as "Page N:" sections; return filepath."""
    sections = [f"Page {page}:\n{text}\n\n" for page, text in text_dict.items()]
    with open(filepath, "w", encoding="utf-8") as out:
        out.writelines(sections)
    return filepath
def process_pdf(pdf_path, process_id):
    """Background worker: extract all text/tables from pdf_path.

    Updates the module-level process_status / process_results dicts under the
    given process_id, writes the combined text to a temp file whose path is
    stored in app.config['file_path'], and flips data_ready when done.
    Runs in the worker thread started by /upload.
    """
    global extracted_texts, data_ready
    with lock:
        data_ready = False  # Reset flag when new process starts
    process_status[process_id] = "in_progress"
    try:
        categorized_pages = categorize_pdf_pages(pdf_path)
        extracted_texts = save_text_pages_as_images(pdf_path, categorized_pages)
        table_texts = extract_text_from_table_pages(pdf_path, categorized_pages)
        extracted_texts.update(table_texts)
        temp_file_path = os.path.join(TEMP_DIR, f"extracted_{process_id}.txt")
        filepath = save_extracted_text(extracted_texts, temp_file_path)  # Save extracted text to file
        app.config['file_path'] = filepath
        process_status[process_id] = "completed"
        process_results[process_id] = {
            "response": extracted_texts,
        }
    except Exception as exc:
        # BUG FIX: without this handler a failure left the status stuck at
        # "in_progress" forever and the /status "error" branch — which reads
        # process_results[process_id]["error"] — was unreachable.
        process_status[process_id] = "error"
        process_results[process_id] = {"error": str(exc)}
    with lock:
        data_ready = True  # Mark extraction as complete
@app.route('/upload', methods=['POST'])
def upload_pdf():
    """Accept a PDF upload and kick off extraction in a background thread.

    Returns JSON containing the process_id the client polls via /status.
    """
    global extracted_texts, data_ready
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    file = request.files['file']
    # SECURITY FIX: basename() strips any directory components so a filename
    # like "../../etc/x.pdf" cannot escape the uploads directory.
    filename = os.path.basename(file.filename or '')
    if not filename:
        return jsonify({'error': 'No file provided'}), 400
    pdf_path = os.path.join(TEMP_DIR, "uploads", filename)
    os.makedirs(os.path.dirname(pdf_path), exist_ok=True)
    file.save(pdf_path)
    process_id = str(uuid.uuid4())
    thread = threading.Thread(target=process_pdf, args=(pdf_path, process_id))
    thread.start()  # Start extraction in a separate thread
    return jsonify({'message': 'File uploaded, extraction in progress', "process_id": process_id})
@app.route('/status', methods=['GET'])
def check_task_status():
    """Poll endpoint: report the state of a background extraction job.

    Query param: process_id (the uuid returned by /upload).
    """
    process_id = request.args.get('process_id', None)
    if process_id not in process_status:
        return jsonify({"error": "Invalid process ID"}), 400
    status = process_status[process_id]
    if status == "completed":
        result = process_results[process_id]
        response = result["response"]
        return jsonify({
            "status": "completed",
            "response": response,
            "url": f"/download?file_path={app.config['file_path']}"
        }), 200
    elif status == "in_progress":
        return jsonify({"status": "in_progress"}), 200
    elif status == "error":
        return jsonify({"status": "error", "error": process_results[process_id]["error"]}), 500
    # BUG FIX: previously any other status value fell off the end of the
    # function, returning None, which Flask turns into a bare 500.
    return jsonify({"status": status}), 200
@app.route('/query', methods=['POST'])
def query_extracted_data():
    """Answer a user question against the text extracted for process_id.

    Query param: process_id; form field: user_question.
    """
    process_id = request.args.get('process_id')
    # BUG FIX: an unknown or still-running process_id previously raised
    # KeyError and surfaced as an opaque HTTP 500.
    result = process_results.get(process_id)
    if not result or "response" not in result:
        return jsonify({"error": "Invalid or incomplete process ID"}), 400
    text = result["response"]
    user_input = request.form['user_question']
    llm_instruction = """You are an AI assistant that strictly follows the given data and information without making assumptions, performing calculations, or using algorithms to infer values. Retrieve answers only from explicitly provided data."""
    response = llm.invoke(llm_instruction + " " + str(text) + " " + user_input)
    return jsonify({'response': response.content.strip()})
@app.route("/download")
def download_file():
    """Serve the most recently extracted text file as an attachment."""
    file_path = app.config.get('file_path')
    # BUG FIX: also verify the file still exists — a stale path (e.g. the
    # temp dir was cleaned) previously made send_file raise a 500.
    if file_path and os.path.isfile(file_path):
        return send_file(file_path, as_attachment=True)
    else:
        return jsonify({"message": "File path is missing."}), 404
if __name__ == '__main__':
    # Local development entry point; a production deployment should run the
    # app via a WSGI server instead.
    app.run(debug=False)
# Start Ngrok in a separate thread
# def start_ngrok():
# public_url = ngrok.connect(8000)
# print(f"Ngrok public URL: {public_url}")
# ngrok_thread = Thread(target=start_ngrok)
# ngrok_thread.start()
# # Run Flask app
# app.run(port=8000)