# Spaces: Running
import functools
import json
import os

import gradio as gr
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "results.json"

# Weight file names that `save_pretrained` may have written into MODEL_PATH.
# Recent transformers releases default to safetensors, older ones wrote the
# pickle-based .bin file, so either one means "a model is already saved here".
_WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")

# Ensure model exists
if not any(os.path.exists(os.path.join(MODEL_PATH, name)) for name in _WEIGHT_FILES):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): this fallback pulls the *base* DistilBERT checkpoint, whose
    # 2-label classification head is freshly initialised rather than trained
    # for spam detection -- predictions from it are presumably not meaningful
    # until fine-tuned weights are placed in MODEL_PATH.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Ensure model is in evaluation mode (disables dropout for inference)
model.eval()
# OCR Functions
@functools.lru_cache(maxsize=1)
def _paddle_engine():
    """Build the PaddleOCR engine once and reuse it for every request."""
    return PaddleOCR(lang='en', use_angle_cls=True)


def ocr_with_paddle(img):
    """Extract text from *img* with PaddleOCR.

    Args:
        img: Image passed straight to ``PaddleOCR.ocr`` (the caller in this
            file supplies a NumPy array).

    Returns:
        The recognised text fragments joined by single spaces, or an empty
        string when nothing was detected.
    """
    result = _paddle_engine().ocr(img)
    # The per-image entry can be None / empty when no text is found; iterating
    # it unguarded would raise TypeError instead of reporting "no text".
    if not result or not result[0]:
        return ''
    # Each detection is indexed as item[1][0] to pull out its text string.
    return ' '.join(item[1][0] for item in result[0])
@functools.lru_cache(maxsize=1)
def _keras_pipeline():
    """Build the keras-ocr pipeline once and reuse it for every request."""
    return keras_ocr.pipeline.Pipeline()


def ocr_with_keras(img):
    """Extract text from *img* with keras-ocr.

    Args:
        img: Image handed to ``keras_ocr.tools.read`` (the caller in this
            file supplies a NumPy array).

    Returns:
        The recognised words joined by single spaces; an empty string when
        the pipeline reports no words.
    """
    images = [keras_ocr.tools.read(img)]
    # One prediction list per input image; only a single image is submitted.
    predictions = _keras_pipeline().recognize(images)
    # Every prediction is a (text, box) pair -- the box is not needed here.
    return ' '.join(text for text, _ in predictions[0])
@functools.lru_cache(maxsize=1)
def _easyocr_reader():
    """Build the English EasyOCR reader once and reuse it for every request."""
    return easyocr.Reader(['en'])


def ocr_with_easy(img):
    """Extract text from *img* with EasyOCR after converting it to grayscale.

    Args:
        img: NumPy image array. The only caller in this file passes
            ``np.array(...)`` of the Gradio image input, which is assumed to
            be in RGB(A) channel order -- TODO confirm if other callers pass
            OpenCV-style BGR arrays.

    Returns:
        The recognised text fragments joined by single spaces.
    """
    if img.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        gray_image = img
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        # RGB, not BGR: the array does not come from cv2.imread.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # detail=0 makes readtext return plain strings instead of (box, text, conf).
    results = _easyocr_reader().readtext(gray_image, detail=0)
    return ' '.join(results)
# Save results to JSON
def save_to_json(text, label, path=None):
    """Append one OCR/classification record to a JSON log file.

    The file holds a JSON list of ``{"text": ..., "classification": ...}``
    objects. A missing, unreadable-as-JSON, or non-list file is treated as an
    empty log and is overwritten with a fresh list containing the new record.

    Args:
        text: Extracted text to store under the ``"text"`` key.
        label: Classification to store under the ``"classification"`` key.
        path: Destination file. Defaults to the module-level ``RESULTS_JSON``.
    """
    target = RESULTS_JSON if path is None else path
    record = {"text": text, "classification": label}

    # EAFP: open directly instead of exists()-then-open, which can race with
    # the file being removed between the two calls.
    try:
        with open(target, "r", encoding="utf-8") as file:
            results = json.load(file)
    except (FileNotFoundError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        results = []

    # Valid JSON that is not a list (e.g. a lone object) cannot be appended to;
    # handle it the same way as a corrupt file rather than crashing.
    if not isinstance(results, list):
        results = []

    results.append(record)
    with open(target, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4)
# OCR & Classification Function
def generate_ocr(method, img):
    """Run the chosen OCR engine on an image and classify the extracted text.

    Args:
        method: Engine name. "PaddleOCR" and "EasyOCR" select those engines;
            any other value falls through to KerasOCR.
        img: Uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(text, label)`` pair. ``label`` is "Spam" or "Not Spam"; when no
        text is found the pair is ``("No text detected!", "Cannot classify")``.

    Raises:
        gr.Error: If ``img`` is None.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # The OCR helpers work on array data, so normalise the input first.
    pixels = np.array(img)

    # Pick the OCR backend; anything unrecognised is handled by KerasOCR.
    engines = {"PaddleOCR": ocr_with_paddle, "EasyOCR": ocr_with_easy}
    run_ocr = engines.get(method, ocr_with_keras)
    extracted = run_ocr(pixels).strip()

    if not extracted:
        return "No text detected!", "Cannot classify"

    # Tokenize, truncating to the 512-token limit passed to the tokenizer.
    encoded = tokenizer(extracted, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Index 1 of the softmax output is used as the spam probability.
    spam_prob = F.softmax(logits, dim=1)[0][1].item()
    verdict = "Spam" if spam_prob > 0.5 else "Not Spam"

    # Persist the result before handing it back to the UI.
    save_to_json(extracted, verdict)
    return extracted, verdict
# Gradio Interface
# Input widgets: the image to scan and the OCR engine selector.
image_input = gr.Image()
method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
)

# Output widgets: the raw OCR text and the classifier verdict.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

# Input order matches generate_ocr(method, img); output order matches its
# (text, label) return value.
demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
    theme="compact",
)
# Launch App
# Start the Gradio server only when this file is executed as a script.
# The guard must compare against "__main__" (double underscores on both
# sides); with "_main_" the condition is never true and the app never starts.
if __name__ == "__main__":
    demo.launch()