Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig | |
import cv2 | |
import numpy as np | |
import easyocr | |
import keras_ocr | |
from paddleocr import PaddleOCR | |
import os | |
# Directory the tokenizer and the classification model are loaded from.
MODEL_PATH = "./distilbert_spam_model"

# When that directory has no config.json, write a default two-label
# "distilbert-base-uncased" configuration into it before loading.
_config_file = os.path.join(MODEL_PATH, "config.json")
if not os.path.exists(_config_file):
    print("config.json not found. Generating default configuration...")
    config = DistilBertConfig.from_pretrained(
        "distilbert-base-uncased",
        num_labels=2,
    )
    config.save_pretrained(MODEL_PATH)

# Load tokenizer and model from the same directory.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
# Define Spam Classification Function | |
def classify_text(text):
    """Label ``text`` as "Spam" or "Not Spam" with the loaded classifier.

    The text is tokenized (truncated to at most 512 tokens), passed through
    the module-level ``model`` with gradient tracking disabled, and the index
    of the largest logit selects the label.

    Args:
        text: The text to classify.

    Returns:
        "Spam" when the predicted class index is 1, otherwise "Not Spam".
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    if predicted_class == 1:
        return "Spam"
    return "Not Spam"
# OCR Methods | |
def ocr_with_paddle(img):
    """Extract text from an image using PaddleOCR.

    Args:
        img: The image to run OCR on; passed unchanged to ``PaddleOCR.ocr``.

    Returns:
        The recognized text fragments joined by single spaces, or an empty
        string when nothing was recognized.
    """
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    # The per-image entry can be None (or empty) when no text is detected;
    # iterating it would raise TypeError, so report "no text" instead.
    if not result or not result[0]:
        return ""
    # Each detected line is indexed as entry[1][0] to get its text.
    return ' '.join(entry[1][0] for entry in result[0])
def ocr_with_keras(img):
    """Extract text from an image using the keras-ocr pipeline.

    Args:
        img: The image to run OCR on; read via ``keras_ocr.tools.read``.

    Returns:
        The recognized words for the image joined by single spaces.
    """
    recognizer = keras_ocr.pipeline.Pipeline()
    batch = [keras_ocr.tools.read(img)]
    # recognize() yields one list of (word, box) pairs per input image;
    # only the single image in the batch is of interest here.
    words = []
    for word, _box in recognizer.recognize(batch)[0]:
        words.append(word)
    return ' '.join(words)
def ocr_with_easy(img):
    """Extract text from an image using EasyOCR (English only).

    Args:
        img: The image to run OCR on; passed unchanged to ``readtext``.

    Returns:
        The items returned by ``readtext(..., detail=0)`` joined by single
        spaces.
    """
    fragments = easyocr.Reader(['en']).readtext(img, detail=0)
    return ' '.join(fragments)
# OCR + Spam Detection | |
def process_image(ocr_method, image):
    """Run the chosen OCR backend on ``image`` and classify the result.

    Args:
        ocr_method: One of "PaddleOCR", "KerasOCR" or "EasyOCR".
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A message string: an error/notice when the image is missing, the
        method is unknown, or no text was found; otherwise the extracted
        text followed by its spam classification.
    """
    if image is None:
        return "Error: No image uploaded."

    # Thunks keep the lookup lazy: only the selected backend is resolved
    # and executed.
    backends = {
        "PaddleOCR": lambda: ocr_with_paddle(image),
        "KerasOCR": lambda: ocr_with_keras(image),
        "EasyOCR": lambda: ocr_with_easy(image),
    }
    run_ocr = backends.get(ocr_method)
    if run_ocr is None:
        return "Invalid OCR method."

    extracted_text = run_ocr()
    if extracted_text.strip():
        classification = classify_text(extracted_text)
        return f"Extracted Text: {extracted_text}\n\nClassification: {classification}"
    return "No text detected in the image."
# Gradio UI: an image upload plus an OCR-backend selector feeding one text box.
image_input = gr.Image(type="numpy")
ocr_method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
    label="OCR Method",
)
output_text = gr.Textbox(label="OCR & Classification Result")

# Input order matches process_image(ocr_method, image).
interface = gr.Interface(
    fn=process_image,
    inputs=[ocr_method_input, image_input],
    outputs=output_text,
    title="OCR + Spam Detection",
    description=(
        "Upload an image with text, extract the text using OCR, "
        "and classify it as Spam or Not Spam using DistilBERT."
    ),
    theme="compact",
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()