import gradio as gr
import torch
import json
import csv
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F
from save_results import save_results_to_repo # Import the save function
# Paths
MODEL_PATH = "./distilbert_spam_model"
# Ensure model exists
if not os.path.exists(os.path.join(MODEL_PATH, "pytorch_model.bin")):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # Note: this fallback loads the base distilbert-base-uncased checkpoint with a freshly
    # initialised (untrained) classification head, so predictions will not be meaningful
    # until the model has been fine-tuned.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
# 🔹 Ensure model is in evaluation mode
model.eval()
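# Optional sanity check (illustrative only, not executed by the app): the classifier can be
# exercised on a raw string before wiring up OCR, e.g.:
#
#   sample = "Congratulations! You won a free prize, click here now"
#   enc = tokenizer(sample, return_tensors="pt", truncation=True, max_length=512)
#   with torch.no_grad():
#       spam_prob = F.softmax(model(**enc).logits, dim=1)[0][1].item()
#   print(f"spam probability: {spam_prob:.3f}")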
# OCR Functions
def ocr_with_paddle(img):
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    if not result or result[0] is None:  # PaddleOCR returns an empty/None result when nothing is detected
        return ''
    # Each result item is [bounding_box, (text, confidence)]; keep only the text
    return ' '.join([item[1][0] for item in result[0]])
def ocr_with_keras(img):
    pipeline = keras_ocr.pipeline.Pipeline()
    images = [keras_ocr.tools.read(img)]
    predictions = pipeline.recognize(images)
    return ' '.join([text for text, _ in predictions[0]])
def ocr_with_easy(img):
    gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    reader = easyocr.Reader(['en'])
    results = reader.readtext(gray_image, detail=0)
    return ' '.join(results)
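# Note: each helper above builds a fresh OCR engine on every call, which reloads model
# weights and makes every request slow. One possible (unapplied here) optimisation is to
# cache the readers at module level, e.g. for EasyOCR:
#
#   _easy_reader = None
#   def get_easy_reader():
#       global _easy_reader
#       if _easy_reader is None:
#           _easy_reader = easyocr.Reader(['en'])
#       return _easy_reader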
# OCR & Classification Function
def generate_ocr(method, img):
    if img is None:
        raise gr.Error("Please upload an image!")
    # Convert the PIL image to a NumPy array for the OCR backends
    img = np.array(img)
    # Select OCR method
    if method == "PaddleOCR":
        text_output = ocr_with_paddle(img)
    elif method == "EasyOCR":
        text_output = ocr_with_easy(img)
    else:  # KerasOCR
        text_output = ocr_with_keras(img)
    # Clean up the extracted text
    text_output = text_output.strip()
    if len(text_output) == 0:
        return "No text detected!", "Cannot classify"
    # Tokenize text
    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
    probs = F.softmax(outputs.logits, dim=1)  # Convert logits to probabilities
    spam_prob = probs[0][1].item()  # Probability of Spam
    # Classify using a probability threshold (0.5 is equivalent to argmax; the threshold can be tuned)
    label = "Spam" if spam_prob > 0.5 else "Not Spam"
    # Save results using external function
    save_results_to_repo(text_output, label)
    return text_output, label
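# Example direct call, bypassing the UI (illustrative only; assumes a local image file
# "sample.png" exists):
#
#   from PIL import Image
#   text, label = generate_ocr("EasyOCR", Image.open("sample.png"))
#   print(text, label)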
# Gradio Interface
image_input = gr.Image()
method_input = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")
demo = gr.Interface(
    generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
    theme="compact",
)
# Launch App
if __name__ == "__main__":
    demo.launch()
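
# `save_results_to_repo` is imported from a separate save_results.py that is not shown in
# this file. A minimal sketch of what such a function could look like (an assumption for
# illustration, not the actual implementation):
#
#   import csv, os
#   def save_results_to_repo(text, label, path="ocr_results.csv"):
#       is_new = not os.path.exists(path)
#       with open(path, "a", newline="", encoding="utf-8") as f:
#           writer = csv.writer(f)
#           if is_new:
#               writer.writerow(["extracted_text", "label"])
#           writer.writerow([text, label])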