import gradio as gr
import torch
import json
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "results.json"
# Load a locally saved model if present; otherwise fall back to the base checkpoint
# (newer transformers versions save model.safetensors instead of pytorch_model.bin)
if not (
    os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin"))
    or os.path.exists(os.path.join(MODEL_PATH, "model.safetensors"))
):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
# Ensure model is in evaluation mode
model.eval()
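# Note: when the fallback above downloads distilbert-base-uncased, the classification
# head is freshly initialized (not fine-tuned for spam), so predictions are only
# meaningful once fine-tuned weights are placed in MODEL_PATH.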
# OCR Functions
def ocr_with_paddle(img):
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # PaddleOCR may return an empty or None first element when no text is detected
    if not result or not result[0]:
        return ''
    return ' '.join([item[1][0] for item in result[0]])
def ocr_with_keras(img):
    pipeline = keras_ocr.pipeline.Pipeline()
    images = [keras_ocr.tools.read(img)]
    predictions = pipeline.recognize(images)
    return ' '.join([text for text, _ in predictions[0]])
def ocr_with_easy(img):
    # The image arrives as an RGB NumPy array from Gradio, so convert from RGB, not BGR
    gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    reader = easyocr.Reader(['en'])
    results = reader.readtext(gray_image, detail=0)
    return ' '.join(results)
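# Note: each OCR helper above constructs its engine (PaddleOCR, keras-ocr pipeline,
# EasyOCR reader) on every call, so expect noticeable per-request latency.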
# Save results to JSON
def save_to_json(text, label):
    data = {"text": text, "classification": label}
    if os.path.exists(RESULTS_JSON):
        with open(RESULTS_JSON, "r") as file:
            try:
                results = json.load(file)
            except json.JSONDecodeError:
                results = []
    else:
        results = []
    results.append(data)
    with open(RESULTS_JSON, "w") as file:
        json.dump(results, file, indent=4)
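# Illustrative results.json layout after a few runs (example values only):
# [
#     {"text": "WIN a FREE prize now!!!", "classification": "Spam"},
#     {"text": "Lunch at noon?", "classification": "Not Spam"}
# ]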
# OCR & Classification Function
def generate_ocr(method, img):
    if img is None:
        raise gr.Error("Please upload an image!")
    # Ensure the image is a NumPy array (Gradio may pass a PIL Image or an array)
    img = np.array(img)
    # Select OCR method
    if method == "PaddleOCR":
        text_output = ocr_with_paddle(img)
    elif method == "EasyOCR":
        text_output = ocr_with_easy(img)
    else:  # KerasOCR
        text_output = ocr_with_keras(img)
    text_output = text_output.strip()
    if len(text_output) == 0:
        return "No text detected!", "Cannot classify"
    # Tokenize the extracted text
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Perform inference (logit index 1 is treated as the "Spam" class)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=1)
        spam_prob = probs[0][1].item()
    label = "Spam" if spam_prob > 0.5 else "Not Spam"
    # Save results to JSON
    save_to_json(text_output, label)
    return text_output, label
# Gradio Interface
image_input = gr.Image()
method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")
demo = gr.Interface(
    generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
    theme="compact",
)
# Launch App
if __name__ == "_main_":
demo.launch()