ocr-llm-test / app.py
winamnd's picture
Update app.py
c623da2 verified
raw
history blame
2.91 kB
import gradio as gr
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
import os
# Directory holding the DistilBERT spam-classifier files.
MODEL_PATH = "./distilbert_spam_model"

# When the directory has no config.json, write the stock
# distilbert-base-uncased configuration with two labels into it.
# NOTE(review): this only creates config.json; the tokenizer and model loads
# below still read everything else from MODEL_PATH - confirm those files ship
# with the app.
_config_file = os.path.join(MODEL_PATH, "config.json")
if not os.path.exists(_config_file):
    print("config.json not found. Generating default configuration...")
    config = DistilBertConfig.from_pretrained(
        "distilbert-base-uncased",
        num_labels=2,
    )
    config.save_pretrained(MODEL_PATH)

# Tokenizer and classifier are loaded once at import time and shared by
# every request.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
# Define Spam Classification Function
def classify_text(text):
    """Label *text* as spam or not using the module-level DistilBERT model.

    Args:
        text: The text to classify. It is tokenized with truncation to at
            most 512 tokens.

    Returns:
        ``"Spam"`` when the highest-scoring class index is 1, otherwise
        ``"Not Spam"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    predicted_label = torch.argmax(logits, dim=-1).item()
    if predicted_label == 1:
        return "Spam"
    return "Not Spam"
# OCR Methods
def ocr_with_paddle(img, ocr=None):
    """Extract text from an image with PaddleOCR.

    Args:
        img: The image to read; it is handed to the engine's ``ocr`` method
            unchanged.
        ocr: Optional pre-built engine exposing ``ocr(img)``. When omitted, a
            new ``PaddleOCR(lang='en', use_angle_cls=True)`` is created for
            this call. Passing a shared engine avoids reloading the models on
            every request.

    Returns:
        The recognized text fragments joined by single spaces, or ``""``
        when the engine reports no text.
    """
    if ocr is None:
        ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # An image without text comes back as [None] (or an empty result);
    # iterating it directly would raise instead of letting the caller report
    # "no text detected".
    if not result or not result[0]:
        return ""
    # Each entry is indexed as entry[1][0] to reach the recognized string.
    return ' '.join(entry[1][0] for entry in result[0])
def ocr_with_keras(img):
    """Extract text from an image with the keras-ocr pipeline.

    Args:
        img: The image to read; it is passed through ``keras_ocr.tools.read``
            before recognition.

    Returns:
        The recognized words of the single submitted image joined by spaces.
    """
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]
    # Only one image is submitted, so only the first prediction list matters.
    # Each entry is a pair whose first item is the recognized text.
    first_image = recognizer.recognize(batch)[0]
    words = [word for word, _box in first_image]
    return ' '.join(words)
def ocr_with_easy(img):
    """Extract text from an image with EasyOCR (English only).

    Args:
        img: The image to read; it is handed to ``Reader.readtext`` unchanged.

    Returns:
        The items returned by ``readtext(..., detail=0)`` joined by spaces.
    """
    english_reader = easyocr.Reader(['en'])
    fragments = english_reader.readtext(img, detail=0)
    return ' '.join(fragments)
# OCR + Spam Detection
def process_image(ocr_method, image):
    """Run the chosen OCR backend on *image* and classify the extracted text.

    Args:
        ocr_method: One of ``"PaddleOCR"``, ``"KerasOCR"`` or ``"EasyOCR"``.
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A message string: an error for a missing image or unknown method, a
        notice when no text was found, or the extracted text followed by its
        spam classification.
    """
    if image is None:
        return "Error: No image uploaded."

    # Pick the backend first so the OCR call happens in exactly one place.
    if ocr_method == "PaddleOCR":
        run_ocr = ocr_with_paddle
    elif ocr_method == "KerasOCR":
        run_ocr = ocr_with_keras
    elif ocr_method == "EasyOCR":
        run_ocr = ocr_with_easy
    else:
        return "Invalid OCR method."

    text = run_ocr(image)
    if text.strip() == "":
        return "No text detected in the image."

    label = classify_text(text)
    return f"Extracted Text: {text}\n\nClassification: {label}"
# Gradio UI: one image upload plus an OCR backend selector feeding
# process_image, with the result shown in a single text box.
image_input = gr.Image(type="numpy")
ocr_method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
    label="OCR Method",
)
output_text = gr.Textbox(label="OCR & Classification Result")

# Input order must match process_image(ocr_method, image).
interface = gr.Interface(
    fn=process_image,
    inputs=[ocr_method_input, image_input],
    outputs=output_text,
    title="OCR + Spam Detection",
    description="Upload an image with text, extract the text using OCR, and classify it as Spam or Not Spam using DistilBERT.",
    theme="compact",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()