File size: 2,907 Bytes
19736cf
025580f
c623da2
19736cf
c623da2
19736cf
c623da2
19736cf
c623da2
025580f
c623da2
 
025580f
c623da2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
025580f
c623da2
19736cf
 
 
c623da2
 
c4cf574
19736cf
025580f
 
 
c623da2
 
c4cf574
19736cf
025580f
c623da2
 
025580f
c623da2
 
 
 
c4cf574
c623da2
 
 
 
 
 
 
 
 
 
 
025580f
c623da2
 
a4bd204
c623da2
 
 
 
c4cf574
c623da2
 
 
 
 
 
 
19736cf
025580f
c623da2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
import cv2
import numpy as np
import easyocr
import keras_ocr
from paddleocr import PaddleOCR
import os

# Local directory holding the fine-tuned DistilBERT spam classifier.
MODEL_PATH = "./distilbert_spam_model"

# If the model directory has no config.json, write a default two-label
# DistilBERT config (taken from "distilbert-base-uncased") into it so the
# from_pretrained() calls below have a config to read.
_config_file = os.path.join(MODEL_PATH, "config.json")
if not os.path.exists(_config_file):
    print("config.json not found. Generating default configuration...")
    config = DistilBertConfig.from_pretrained("distilbert-base-uncased", num_labels=2)
    config.save_pretrained(MODEL_PATH)

# Load the tokenizer and classifier once, at import time, from MODEL_PATH.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)

# Spam classification
def classify_text(text):
    """Classify *text* as spam or not with the module-level DistilBERT model.

    The text is tokenized with the module-level ``tokenizer`` (truncated to at
    most 512 tokens) and run through ``model`` without gradient tracking.

    Args:
        text: The string to classify.

    Returns:
        ``"Spam"`` when the highest-scoring class index is 1, otherwise
        ``"Not Spam"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        scores = model(**encoded).logits
    predicted_label = scores.argmax(dim=-1).item()
    if predicted_label == 1:
        return "Spam"
    return "Not Spam"

# OCR Methods

# Lazily created PaddleOCR engine, reused across calls so its models are
# loaded once instead of on every request.
# NOTE(review): the engine is shared without a lock; this assumes Gradio runs
# this handler one request at a time (its default) — confirm before raising
# concurrency.
_paddle_ocr = None


def _get_paddle_ocr():
    """Return the shared PaddleOCR engine, creating it on first use."""
    global _paddle_ocr
    if _paddle_ocr is None:
        _paddle_ocr = PaddleOCR(lang='en', use_angle_cls=True)
    return _paddle_ocr


def ocr_with_paddle(img):
    """Extract text from *img* using PaddleOCR.

    Args:
        img: Image handed directly to ``PaddleOCR.ocr`` (the Gradio UI in
            this file supplies a NumPy array).

    Returns:
        The recognized text fragments joined by single spaces, or ``""``
        when PaddleOCR reports no text.
    """
    result = _get_paddle_ocr().ocr(img)
    # The code reads the first page as result[0]. When no text is detected
    # PaddleOCR can return [None] (or an empty list); iterating that used to
    # raise TypeError, so treat it as "no text".
    page = result[0] if result else None
    if not page:
        return ''
    # Each entry's [1][0] element is the recognized text string.
    return ' '.join(entry[1][0] for entry in page)

# Lazily created keras-ocr pipeline, reused across calls so the detector and
# recognizer are built once instead of on every request.
# NOTE(review): shared without a lock; assumes one request at a time.
_keras_pipeline = None


def _get_keras_pipeline():
    """Return the shared keras-ocr pipeline, creating it on first use."""
    global _keras_pipeline
    if _keras_pipeline is None:
        _keras_pipeline = keras_ocr.pipeline.Pipeline()
    return _keras_pipeline


def ocr_with_keras(img):
    """Extract text from *img* using keras-ocr.

    Args:
        img: Anything accepted by ``keras_ocr.tools.read`` (the Gradio UI in
            this file supplies a NumPy array).

    Returns:
        The recognized words of the image joined by single spaces (``""``
        when nothing is recognized).
    """
    images = [keras_ocr.tools.read(img)]
    predictions = _get_keras_pipeline().recognize(images)
    # predictions[0] holds the (word, box) pairs for our single image.
    return ' '.join(text for text, _ in predictions[0])

# Lazily created EasyOCR reader, reused across calls so the English models are
# loaded once instead of on every request.
# NOTE(review): shared without a lock; assumes one request at a time.
_easy_reader = None


def _get_easy_reader():
    """Return the shared English EasyOCR reader, creating it on first use."""
    global _easy_reader
    if _easy_reader is None:
        _easy_reader = easyocr.Reader(['en'])
    return _easy_reader


def ocr_with_easy(img):
    """Extract text from *img* using EasyOCR.

    Args:
        img: Image handed directly to ``Reader.readtext`` (the Gradio UI in
            this file supplies a NumPy array).

    Returns:
        The detected text strings joined by single spaces (``""`` when
        nothing is detected).
    """
    # detail=0 makes readtext return plain strings instead of (box, text, conf).
    results = _get_easy_reader().readtext(img, detail=0)
    return ' '.join(results)

# OCR + spam detection pipeline
def process_image(ocr_method, image):
    """Run the chosen OCR engine on *image* and classify the extracted text.

    Args:
        ocr_method: Engine name: ``"PaddleOCR"``, ``"KerasOCR"`` or
            ``"EasyOCR"``.
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A display string: an error/notice message, or the extracted text
        followed by its spam classification.
    """
    if image is None:
        return "Error: No image uploaded."

    # Resolve the engine first; an unknown name returns before any OCR runs.
    if ocr_method == "PaddleOCR":
        run_ocr = ocr_with_paddle
    elif ocr_method == "KerasOCR":
        run_ocr = ocr_with_keras
    elif ocr_method == "EasyOCR":
        run_ocr = ocr_with_easy
    else:
        return "Invalid OCR method."

    text = run_ocr(image)
    has_text = bool(text.strip())
    if has_text:
        label = classify_text(text)
        return f"Extracted Text: {text}\n\nClassification: {label}"
    return "No text detected in the image."

# Gradio UI: an uploaded image plus an engine picker feed process_image, and
# its result string is shown in a single textbox.
image_input = gr.Image(type="numpy")
ocr_method_input = gr.Radio(
    ["PaddleOCR", "EasyOCR", "KerasOCR"],
    value="PaddleOCR",
    label="OCR Method",
)
output_text = gr.Textbox(label="OCR & Classification Result")

interface = gr.Interface(
    fn=process_image,
    # Order must match process_image(ocr_method, image).
    inputs=[ocr_method_input, image_input],
    outputs=output_text,
    title="OCR + Spam Detection",
    description=(
        "Upload an image with text, extract the text using OCR, "
        "and classify it as Spam or Not Spam using DistilBERT."
    ),
    theme="compact",
)

# Entry point: start the web app only when this file is run as a script.
def main():
    """Launch the Gradio interface defined above."""
    interface.launch()


if __name__ == "__main__":
    main()