File size: 3,640 Bytes
19736cf
025580f
104c39e
 
19736cf
c623da2
19736cf
c623da2
19736cf
104c39e
deb409e
025580f
104c39e
c623da2
17f2d95
025580f
104c39e
 
14299e0
a92b56c
104c39e
 
 
14299e0
104c39e
 
 
c623da2
17f2d95
deb409e
 
17f2d95
19736cf
 
 
104c39e
c4cf574
19736cf
025580f
 
 
104c39e
c4cf574
19736cf
104c39e
025580f
104c39e
c623da2
025580f
17f2d95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4639dba
 
104c39e
4639dba
104c39e
 
 
 
 
 
 
 
 
 
 
 
deb409e
 
4639dba
deb409e
17f2d95
a92b56c
 
17f2d95
104c39e
 
17f2d95
 
a92b56c
4639dba
104c39e
17f2d95
 
104c39e
 
 
e08cb5e
4639dba
 
 
 
 
 
 
 
 
 
 
 
 
0ab07a8
4ee3a20
f56fe40
17f2d95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import torch
import json
import os
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch.nn.functional as F

# Paths
MODEL_PATH = "./distilbert_spam_model"
RESULTS_JSON = "results.json"

# Weight filenames ``save_pretrained`` can produce: safetensors (the default in
# recent transformers releases), the legacy PyTorch pickle, and the index files
# used when a checkpoint is sharded.
_MODEL_WEIGHT_FILES = (
    "model.safetensors",
    "pytorch_model.bin",
    "model.safetensors.index.json",
    "pytorch_model.bin.index.json",
)


def _model_weights_present(model_dir):
    """Return True if *model_dir* contains saved model weights in any known format."""
    return any(
        os.path.isfile(os.path.join(model_dir, name))
        for name in _MODEL_WEIGHT_FILES
    )


# Ensure model exists. Checking only for pytorch_model.bin would miss a
# checkpoint saved as safetensors and overwrite it with the base model below.
if not _model_weights_present(MODEL_PATH):
    print(f"⚠️ Model not found in {MODEL_PATH}. Downloading from Hugging Face Hub...")
    # NOTE: "distilbert-base-uncased" is a base checkpoint, so the 2-label
    # classification head created here is freshly initialised (untrained);
    # Spam/Not Spam predictions are not meaningful until the model is fine-tuned.
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model.save_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    tokenizer.save_pretrained(MODEL_PATH)
    print(f"✅ Model saved at {MODEL_PATH}.")
else:
    model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)

# Inference only: put the model in evaluation mode (disables dropout).
model.eval()

# OCR Functions
# Lazily-created PaddleOCR engine shared by all calls (see _get_paddle_ocr).
_paddle_ocr_engine = None


def _get_paddle_ocr():
    """Return the shared PaddleOCR engine, constructing it on first use.

    Building the engine loads its model weights, so it is created once instead
    of on every request.
    NOTE(review): the engine is shared across requests; if the app serves
    requests concurrently, confirm PaddleOCR is safe to call from several threads.
    """
    global _paddle_ocr_engine
    if _paddle_ocr_engine is None:
        _paddle_ocr_engine = PaddleOCR(lang='en', use_angle_cls=True)
    return _paddle_ocr_engine


def ocr_with_paddle(img):
    """Extract text from *img* using PaddleOCR.

    Args:
        img: Image as a numpy array (RGB, as produced in ``generate_ocr``).
            NOTE(review): PaddleOCR is usually fed BGR arrays from cv2; confirm
            accuracy is acceptable with RGB input.

    Returns:
        The recognised text fragments joined by single spaces, or ``""`` when
        nothing is detected.
    """
    result = _get_paddle_ocr().ocr(img)
    # One entry per page; a page with no detections comes back as None.
    if not result or not result[0]:
        return ""
    # Each detection is [box, (text, confidence)].
    return ' '.join(detection[1][0] for detection in result[0])

# Lazily-created keras-ocr pipeline shared by all calls (see _get_keras_ocr_pipeline).
_keras_ocr_pipeline = None


def _get_keras_ocr_pipeline():
    """Return the shared keras-ocr Pipeline, constructing it on first use.

    Constructing a Pipeline builds its detector and recognizer models, so it is
    created once instead of on every request.
    NOTE(review): the pipeline is shared across requests; confirm thread-safety
    if the app serves requests concurrently.
    """
    global _keras_ocr_pipeline
    if _keras_ocr_pipeline is None:
        _keras_ocr_pipeline = keras_ocr.pipeline.Pipeline()
    return _keras_ocr_pipeline


def ocr_with_keras(img):
    """Extract text from *img* using keras-ocr.

    Args:
        img: Image as accepted by ``keras_ocr.tools.read`` (here a numpy array
            from ``generate_ocr``).

    Returns:
        The recognised words joined by single spaces.
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_ocr_pipeline().recognize(images)
    # predictions[0] holds (word, box) pairs for the single input image.
    return ' '.join(word for word, _box in predictions[0])

# Lazily-created EasyOCR reader shared by all calls (see _get_easyocr_reader).
_easyocr_reader = None


def _get_easyocr_reader():
    """Return the shared EasyOCR reader, constructing it on first use.

    Creating a Reader loads its models, so it is created once instead of on
    every request.
    NOTE(review): the reader is shared across requests; confirm thread-safety
    if the app serves requests concurrently.
    """
    global _easyocr_reader
    if _easyocr_reader is None:
        _easyocr_reader = easyocr.Reader(['en'])
    return _easyocr_reader


def ocr_with_easy(img):
    """Extract text from *img* using EasyOCR on a grayscale copy of the image.

    Args:
        img: Image as a numpy array. The caller converts Gradio/PIL input with
            ``np.array``, which yields RGB (or RGBA) channel order, not
            OpenCV's BGR; 2-D (already grayscale) arrays are also accepted.

    Returns:
        The recognised text fragments joined by single spaces.
    """
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]

    if img.ndim == 2:
        gray_image = img
    elif img.shape[2] == 4:
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        # Input is RGB; BGR2GRAY would weight the red and blue channels wrongly.
        gray_image = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    results = _get_easyocr_reader().readtext(gray_image, detail=0)
    return ' '.join(results)

# Save results to JSON
def save_to_json(text, label, path=None):
    """Append one OCR/classification record to the JSON results file.

    The file holds a JSON list of ``{"text": ..., "classification": ...}``
    objects. A missing or undecodable file starts a fresh list (best effort:
    result logging should not fail a classification request). A file holding a
    single non-list JSON value keeps that value as the first list element.

    Args:
        text: The OCR-extracted text.
        label: The classification stored under ``"classification"``.
        path: Results file to update; defaults to ``RESULTS_JSON``.
    """
    if path is None:
        path = RESULTS_JSON
    record = {"text": text, "classification": label}

    try:
        with open(path, "r", encoding="utf-8") as file:
            results = json.load(file)
    except FileNotFoundError:
        results = []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Corrupt content: start over rather than failing the request.
        results = []

    if not isinstance(results, list):
        results = [results]
    results.append(record)

    # Write to a sibling temp file and atomically swap it into place, so a
    # crash mid-write cannot leave a truncated results file behind.
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(results, file, indent=4)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

# OCR & Classification Function
def generate_ocr(method, img):
    """Run OCR on *img* with the selected engine and classify the text as spam.

    The text and label are also appended to the results JSON file.

    Args:
        method: "PaddleOCR", "EasyOCR" or "KerasOCR" (the Radio choices).
        img: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        ``(text, label)`` where label is "Spam" or "Not Spam"; or
        ``("No text detected!", "Cannot classify")`` when OCR finds no text.

    Raises:
        gr.Error: If no image was uploaded or *method* is not a known engine.
    """
    if img is None:
        raise gr.Error("Please upload an image!")

    ocr_engines = {
        "PaddleOCR": ocr_with_paddle,
        "EasyOCR": ocr_with_easy,
        "KerasOCR": ocr_with_keras,
    }
    ocr_fn = ocr_engines.get(method)
    if ocr_fn is None:
        # Previously any unrecognised value silently ran KerasOCR.
        raise gr.Error(f"Unknown OCR method: {method!r}")

    # Normalise to a numpy array (copies a numpy input, converts a PIL image).
    # Channel order stays RGB; the OCR helpers must not assume OpenCV's BGR.
    img = np.array(img)

    text_output = ocr_fn(img).strip()
    if not text_output:
        return "No text detected!", "Cannot classify"

    inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True, max_length=512)

    with torch.no_grad():
        logits = model(**inputs).logits
    # NOTE(review): assumes class index 1 means "spam" — confirm against the
    # label mapping used when the model was fine-tuned.
    spam_prob = F.softmax(logits, dim=1)[0, 1].item()
    label = "Spam" if spam_prob > 0.5 else "Not Spam"

    save_to_json(text_output, label)

    return text_output, label

# Gradio Interface
# generate_ocr and the OCR helpers operate on numpy arrays; "numpy" is also
# gr.Image's default type, stated here so that contract is explicit.
image_input = gr.Image(type="numpy", label="Image")
method_input = gr.Radio(
    ["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
    label="OCR Engine",
)
output_text = gr.Textbox(label="Extracted Text")
output_label = gr.Textbox(label="Spam Classification")

demo = gr.Interface(
    generate_ocr,
    # Order must match generate_ocr(method, img).
    inputs=[method_input, image_input],
    outputs=[output_text, output_label],
    title="OCR Spam Classifier",
    description="Upload an image, extract text, and classify it as Spam or Not Spam.",
    # NOTE(review): "compact" is not one of the built-in theme names in recent
    # Gradio releases; confirm it resolves for the Gradio version in use.
    theme="compact",
)

# Launch App
# Start the server only when run as a script. The module name is "__main__"
# (double underscores); a "_main_" comparison never matches.
if __name__ == "__main__":
    demo.launch()