Spaces:
Running
Running
File size: 3,640 Bytes
19736cf 025580f 104c39e 19736cf c623da2 19736cf c623da2 19736cf 104c39e deb409e 025580f 104c39e c623da2 17f2d95 025580f 104c39e 14299e0 a92b56c 104c39e 14299e0 104c39e c623da2 17f2d95 deb409e 17f2d95 19736cf 104c39e c4cf574 19736cf 025580f 104c39e c4cf574 19736cf 104c39e 025580f 104c39e c623da2 025580f 17f2d95 4639dba 104c39e 4639dba 104c39e deb409e 4639dba deb409e 17f2d95 a92b56c 17f2d95 104c39e 17f2d95 a92b56c 4639dba 104c39e 17f2d95 104c39e e08cb5e 4639dba 0ab07a8 4ee3a20 f56fe40 17f2d95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import gradio as gr
import torch
import json
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "results.json"

# Weight files that `save_pretrained` may have written into MODEL_PATH.
# Recent transformers releases serialise to safetensors by default while older
# ones wrote a PyTorch pickle, so checking only for "pytorch_model.bin" makes a
# freshly saved model look missing and triggers a download on every start.
_WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")

# Ensure model exists: load the local copy when any known weight file is
# present, otherwise fetch the base checkpoint and cache it in MODEL_PATH.
if not any(
    os.path.exists(os.path.join(MODEL_PATH, weight_file))
    for weight_file in _WEIGHT_FILES
):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): "distilbert-base-uncased" is the base checkpoint, so the
    # 2-label classification head is presumably freshly initialised rather
    # than trained for spam detection -- confirm that fine-tuned weights are
    # deployed to MODEL_PATH before trusting the predictions.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Ensure model is in evaluation mode
model.eval()
# OCR Functions
# Shared PaddleOCR engine, created lazily so importing this module stays cheap
# and the engine is built once instead of on every request.
_PADDLE_ENGINE = None


def _get_paddle_engine():
    """Return the shared PaddleOCR engine, building it on first use."""
    global _PADDLE_ENGINE
    if _PADDLE_ENGINE is None:
        _PADDLE_ENGINE = PaddleOCR(lang='en', use_angle_cls=True)
    return _PADDLE_ENGINE


def ocr_with_paddle(img):
    """Extract text from an image with PaddleOCR.

    Args:
        img: Image passed straight to ``PaddleOCR.ocr``.

    Returns:
        The recognised text fragments of the first result page joined with
        single spaces, or an empty string when nothing was recognised.
    """
    result = _get_paddle_engine().ocr(img)
    # Some PaddleOCR versions report a page without text as None (or return an
    # empty list); iterating that would raise TypeError/IndexError, so treat
    # it as "no text".
    if not result or not result[0]:
        return ''
    return ' '.join(item[1][0] for item in result[0])
# Shared keras-ocr pipeline, created lazily and reused across requests rather
# than being rebuilt for every image.
_KERAS_PIPELINE = None


def _get_keras_pipeline():
    """Return the shared keras-ocr pipeline, building it on first use."""
    global _KERAS_PIPELINE
    if _KERAS_PIPELINE is None:
        _KERAS_PIPELINE = keras_ocr.pipeline.Pipeline()
    return _KERAS_PIPELINE


def ocr_with_keras(img):
    """Extract text from an image with keras-ocr.

    Args:
        img: Image handed to ``keras_ocr.tools.read``.

    Returns:
        The words recognised in the image joined with single spaces (an empty
        string when the prediction list is empty).
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_pipeline().recognize(images)
    # One prediction list per input image; each entry is a (text, box) pair.
    return ' '.join(text for text, _ in predictions[0])
# Shared EasyOCR reader, created lazily and reused across requests rather than
# being rebuilt for every image.
_EASY_READER = None


def _get_easy_reader():
    """Return the shared English EasyOCR reader, building it on first use."""
    global _EASY_READER
    if _EASY_READER is None:
        _EASY_READER = easyocr.Reader(['en'])
    return _EASY_READER


def ocr_with_easy(img):
    """Extract text from an image with EasyOCR.

    Args:
        img: Image as a NumPy array, either 2-D (already grayscale) or with a
            trailing channel axis.

    Returns:
        The recognised text fragments joined with single spaces.
    """
    if img.ndim == 2:
        # Already single-channel; cvtColor would reject it.
        gray_image = img
    else:
        # NOTE(review): the caller in this file passes the array obtained from
        # the Gradio image component, which is presumably RGB-ordered, so the
        # RGB conversion is used instead of the BGR one -- confirm if other
        # callers feed OpenCV (BGR) images.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    results = _get_easy_reader().readtext(gray_image, detail=0)
    return ' '.join(results)
# Save results to JSON
def save_to_json(text, label, path=None):
    """Append one OCR/classification record to the JSON results file.

    The file holds a JSON list of ``{"text": ..., "classification": ...}``
    objects. A missing file, a file that cannot be decoded, or a file whose
    top-level value is not a list is treated as an empty history and is
    overwritten with a list containing only the new record.

    Args:
        text: Extracted text to store.
        label: Classification label to store alongside the text.
        path: Optional destination file; defaults to ``RESULTS_JSON``.
    """
    target = RESULTS_JSON if path is None else path
    record = {"text": text, "classification": label}

    try:
        with open(target, "r", encoding="utf-8") as file:
            results = json.load(file)
    except (FileNotFoundError, ValueError):
        # ValueError covers json.JSONDecodeError as well as undecodable bytes.
        results = []

    # Valid JSON that is not a list (e.g. an object) cannot be appended to.
    if not isinstance(results, list):
        results = []

    results.append(record)
    with open(target, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4)
# OCR & Classification Function
def generate_ocr(method, img):
    """Run the chosen OCR engine on an image and classify the extracted text.

    Args:
        method: Name of the OCR engine; "PaddleOCR" and "EasyOCR" select those
            engines, any other value falls back to keras-ocr.
        img: Uploaded image, or None when nothing was uploaded.

    Returns:
        A ``(text, label)`` tuple. When no text is found the tuple is
        ``("No text detected!", "Cannot classify")`` and nothing is saved.

    Raises:
        gr.Error: If ``img`` is None.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # The OCR helpers work on arrays, so normalise the input first.
    pixels = np.array(img)

    # Pick the extractor by name; unknown names use keras-ocr.
    extractors = {"PaddleOCR": ocr_with_paddle, "EasyOCR": ocr_with_easy}
    extract = extractors.get(method, ocr_with_keras)

    extracted = extract(pixels).strip()
    if not extracted:
        return "No text detected!", "Cannot classify"

    # Tokenize, then score without tracking gradients.
    encoded = tokenizer(extracted, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**encoded).logits
        class_probs = F.softmax(logits, dim=1)

    # Index 1 of the softmax output is treated as the spam class.
    spam_prob = class_probs[0][1].item()
    if spam_prob > 0.5:
        label = "Spam"
    else:
        label = "Not Spam"

    # Persist the result before handing it back to the UI.
    save_to_json(extracted, label)
    return extracted, label
# Gradio Interface
# Input widgets: the image to read and the OCR engine to read it with.
image_input = gr.Image()
method_input = gr.Radio(
    choices=["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
)

# Output widgets: the raw OCR text and the classifier's verdict.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

# Wire the widgets to generate_ocr; input order matches its (method, img)
# signature and output order matches its (text, label) return value.
demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    theme="compact",
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
)
# Launch App
# Start the Gradio server only when this file is executed as a script. The
# guard must compare against "__main__" (double underscores); the previous
# "_main_" never matched, so the app was never launched.
if __name__ == "__main__":
    demo.launch()