File size: 3,240 Bytes
19736cf
025580f
104c39e
 
 
19736cf
c623da2
19736cf
c623da2
19736cf
104c39e
a92b56c
f9afb31
025580f
104c39e
c623da2
104c39e
 
025580f
104c39e
 
 
a92b56c
104c39e
 
 
 
 
 
 
c623da2
104c39e
19736cf
 
 
104c39e
c4cf574
19736cf
025580f
 
 
104c39e
c4cf574
19736cf
104c39e
025580f
104c39e
c623da2
025580f
104c39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a92b56c
 
104c39e
 
a92b56c
 
 
 
 
104c39e
f9afb31
 
104c39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19736cf
025580f
104c39e
c623da2
104c39e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import csv
import functools
import json
import os

import gradio as gr
import torch
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F  # Added for softmax

from save_results import save_results_to_repo  # Import the new save function

# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "ocr_results.json"
RESULTS_CSV = "ocr_results.csv"

# Weight files ``save_pretrained`` may produce: recent transformers releases
# write ``model.safetensors`` by default, older ones ``pytorch_model.bin``
# (the ``*.index.json`` variants appear for sharded checkpoints).
_WEIGHT_FILES = (
    "model.safetensors",
    "pytorch_model.bin",
    "model.safetensors.index.json",
    "pytorch_model.bin.index.json",
)


def _has_saved_weights(model_dir):
    """Return True if *model_dir* contains any known model-weights file."""
    return any(
        os.path.isfile(os.path.join(model_dir, name)) for name in _WEIGHT_FILES
    )


# Ensure model exists. Look for every weights-file name, not just
# ``pytorch_model.bin``: otherwise a safetensors checkpoint is never detected
# and the download branch below overwrites it with the base weights.
if not _has_saved_weights(MODEL_PATH):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE(review): "distilbert-base-uncased" has no sequence-classification
    # head, so this 2-label head is newly initialised (not trained for spam);
    # predictions are only meaningful once a fine-tuned checkpoint is placed
    # in MODEL_PATH.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Load OCR Methods
def ocr_with_paddle(img):
    ocr = PaddleOCR(lang='en', use_angle_cls=True)
    result = ocr.ocr(img)
    return ' '.join([item[1][0] for item in result[0]])

@functools.lru_cache(maxsize=None)
def _get_keras_pipeline():
    """Build the keras-ocr pipeline once and reuse it on later calls."""
    # NOTE(review): the cached pipeline is shared by all requests; if Gradio runs
    # handlers concurrently, confirm it is safe to share across threads.
    return keras_ocr.pipeline.Pipeline()


def ocr_with_keras(img):
    """Extract text from *img* with keras-ocr.

    Args:
        img: Image passed to ``keras_ocr.tools.read`` (here an array from
            ``generate_ocr``).

    Returns:
        The recognised words of the single input image joined by spaces.
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_pipeline().recognize(images)
    # predictions holds one list of (word, box) pairs per input image.
    return ' '.join(word for word, _box in predictions[0])

@functools.lru_cache(maxsize=None)
def _get_easyocr_reader():
    """Build the English EasyOCR reader once and reuse it on later calls."""
    # NOTE(review): the cached reader is shared by all requests; if Gradio runs
    # handlers concurrently, confirm it is safe to share across threads.
    return easyocr.Reader(['en'])


def ocr_with_easy(img):
    """Extract text from *img* with EasyOCR after converting it to grayscale.

    Args:
        img: Image array. ``generate_ocr`` passes Gradio's RGB array. 2-D
            (already grayscale), single-channel and RGBA arrays are also accepted.

    Returns:
        The recognised strings joined by single spaces.
    """
    if img.ndim == 2:
        gray_image = img
    elif img.shape[2] == 1:
        gray_image = img[:, :, 0]
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        # The caller's array is RGB (not OpenCV's BGR), so use the RGB weights.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    results = _get_easyocr_reader().readtext(gray_image, detail=0)
    return ' '.join(results)

# OCR Function
def generate_ocr(method, img):
    """Run the chosen OCR engine on *img*, classify the text, and record it.

    Args:
        method: Engine name from the radio button, "PaddleOCR" or "EasyOCR".
            Any other value falls back to KerasOCR.
        img: Uploaded image; converted to a NumPy array with ``np.array``.

    Returns:
        A ``(text, label)`` tuple where label is "Spam" or "Not Spam".

    Raises:
        gr.Error: if no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    # Normalise the upload to an ndarray (channel order is left unchanged).
    pixels = np.array(img)

    engines = {"PaddleOCR": ocr_with_paddle, "EasyOCR": ocr_with_easy}
    run_ocr = engines.get(method, ocr_with_keras)  # everything else -> KerasOCR
    extracted = run_ocr(pixels)

    encoded = tokenizer(
        extracted,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
        class_probs = F.softmax(logits, dim=1)
        class_index = torch.argmax(class_probs, dim=1).item()

    verdict = {0: "Not Spam", 1: "Spam"}[class_index]

    # Persist the (text, label) pair via the project's save helper.
    save_results_to_repo(extracted, verdict)

    return extracted, verdict

# Gradio Interface
# Input widgets: generate_ocr receives (method, image) in that order.
image_input = gr.Image()
method_input = gr.Radio(
    ["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
)

# Output widgets, matching generate_ocr's (text, label) return tuple.
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

demo = gr.Interface(
    fn=generate_ocr,
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
    # NOTE(review): "compact" may not be a built-in theme name in newer Gradio
    # releases; confirm it resolves as intended for the pinned version.
    theme="compact",
)

# Launch App only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()