# ocr-llm-test / app.py
# Author: winamnd
# Commit a30c719 (verified): "Update app.py"
# File size: 3.81 kB
import gradio as gr
import torch
import json
import csv
import os
import cv2
import numpy as np
import pandas as pd
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F # Added for softmax
# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "ocr_results.json"
RESULTS_CSV = "ocr_results.csv"

# File names that indicate a saved checkpoint is already present in MODEL_PATH.
# Recent transformers releases write `model.safetensors` from save_pretrained,
# older ones write `pytorch_model.bin`; checking only the latter makes the
# bootstrap branch below run (and overwrite MODEL_PATH) on every start.
_WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")

# Ensure model exists
if not any(os.path.isfile(os.path.join(MODEL_PATH, name)) for name in _WEIGHT_FILES):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): this fallback starts from the generic base checkpoint, so
    # its 2-label classification head is presumably untrained; confirm that a
    # fine-tuned spam model is expected to be shipped in MODEL_PATH instead.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    # A checkpoint is already on disk: load model and tokenizer from it.
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
# Load OCR Methods
# Lazily created PaddleOCR engine, shared across calls so the models are not
# reloaded for every request.
_PADDLE_ENGINE = None


def ocr_with_paddle(img):
    """Extract text from ``img`` with PaddleOCR.

    Args:
        img: Image passed straight to ``PaddleOCR.ocr`` (the caller in this
            file supplies a NumPy array).

    Returns:
        The recognized text fragments joined by single spaces, or an empty
        string when nothing was detected.
    """
    global _PADDLE_ENGINE
    if _PADDLE_ENGINE is None:
        _PADDLE_ENGINE = PaddleOCR(lang='en', use_angle_cls=True)

    result = _PADDLE_ENGINE.ocr(img)
    # Some PaddleOCR releases report an image without text as `[None]`;
    # iterating that entry would raise TypeError, so treat it as "no text".
    if not result or not result[0]:
        return ''
    # Each entry's recognized string is read from item[1][0].
    return ' '.join(item[1][0] for item in result[0])
# Lazily created keras-ocr pipeline, shared across calls so the pipeline is
# not rebuilt for every request.
_KERAS_PIPELINE = None


def ocr_with_keras(img):
    """Extract text from ``img`` with the keras-ocr pipeline.

    Args:
        img: Image handed to ``keras_ocr.tools.read`` (the caller in this
            file supplies a NumPy array).

    Returns:
        The recognized words joined by single spaces.
    """
    global _KERAS_PIPELINE
    if _KERAS_PIPELINE is None:
        _KERAS_PIPELINE = keras_ocr.pipeline.Pipeline()

    images = [keras_ocr.tools.read(img)]
    predictions = _KERAS_PIPELINE.recognize(images)
    # One prediction list per input image; each item is a (text, box) pair.
    return ' '.join(text for text, _ in predictions[0])
# Lazily created EasyOCR reader, shared across calls so the reader is not
# rebuilt for every request.
_EASYOCR_READER = None


def ocr_with_easy(img):
    """Extract text from ``img`` with EasyOCR after converting it to grayscale.

    Args:
        img: NumPy image array. The caller in this file builds it with
            ``np.array`` from the Gradio upload, i.e. in RGB channel order.

    Returns:
        The recognized text fragments joined by single spaces.
    """
    global _EASYOCR_READER
    if _EASYOCR_READER is None:
        _EASYOCR_READER = easyocr.Reader(['en'])

    if img.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        gray_image = img
    else:
        # The array is RGB, not OpenCV's BGR, so use the RGB conversion to
        # apply the luminance weights to the right channels.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # detail=0 makes readtext return only the text strings.
    results = _EASYOCR_READER.readtext(gray_image, detail=0)
    return ' '.join(results)
# OCR Function
def generate_ocr(method, img):
    """Run the selected OCR engine on an image and classify the text.

    Args:
        method: Engine name; "PaddleOCR" and "EasyOCR" select those engines,
            any other value falls back to KerasOCR.
        img: Uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(text, label)`` tuple where ``label`` is "Spam" or "Not Spam".

    Raises:
        gr.Error: If ``img`` is None.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # The OCR helpers work on a NumPy array.
    pixels = np.array(img)

    # Pick the OCR helper; anything unrecognized uses KerasOCR.
    engines = {
        "PaddleOCR": ocr_with_paddle,
        "EasyOCR": ocr_with_easy,
    }
    extract_text = engines.get(method, ocr_with_keras)
    text_output = extract_text(pixels)

    # Classify the extracted text with the DistilBERT model.
    encoded = tokenizer(
        text_output,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
        class_probs = F.softmax(logits, dim=1)  # logits -> probabilities
        class_index = torch.argmax(class_probs, dim=1).item()

    names_by_class = {0: "Not Spam", 1: "Spam"}
    label = names_by_class[class_index]

    # Persist the result before handing it back to the UI.
    save_results(text_output, label)
    return text_output, label
# Save extracted text to JSON & CSV
def save_results(text, label, *, json_path=None, csv_path=None):
    """Append one OCR/classification record to the JSON and CSV result files.

    Args:
        text: Extracted text to store.
        label: Classification label to store.
        json_path: JSON file to update; defaults to ``RESULTS_JSON``.
        csv_path: CSV file to append to; defaults to ``RESULTS_CSV``.

    Raises:
        json.JSONDecodeError: If the JSON file has content that is not valid
            JSON (an empty file is treated as an empty list).
    """
    json_path = RESULTS_JSON if json_path is None else json_path
    csv_path = RESULTS_CSV if csv_path is None else csv_path
    data = {"text": text, "label": label}

    # Save to JSON: read the existing list (if any), append, rewrite the file.
    content = []
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            raw = f.read()
        # An empty file would make json.load fail; start a fresh list instead.
        if raw.strip():
            content = json.loads(raw)
    content.append(data)
    # Opening with "w" truncates, so no stale bytes can remain after the dump.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(content, f, indent=4)

    # Save to CSV: write the header only when the file is new or empty.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    # Explicit UTF-8 so non-ASCII OCR text does not depend on the locale.
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["text", "label"])
        if needs_header:
            writer.writeheader()
        writer.writerow(data)
# Gradio Interface
# Input components: the uploaded image and the OCR engine selector.
image_input = gr.Image()
method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
)

# Output components: the recognized text and the classifier verdict.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

# Wire generate_ocr(method, img) to the components above; the input order
# matches the function's parameter order.
demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
    theme="compact",
)

# Launch App
if __name__ == "__main__":
    demo.launch()