winamnd committed on
Commit
025580f
·
verified ·
1 Parent(s): ce96f23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -80
app.py CHANGED
@@ -1,35 +1,31 @@
1
  import gradio as gr
2
- import tensorflow as tf
 
3
  import keras_ocr
4
- import requests
5
  import cv2
6
- import os
7
- import csv
8
- import numpy as np
9
- import pandas as pd
10
- import huggingface_hub
11
- from huggingface_hub import Repository
12
- from datetime import datetime
13
- import scipy.ndimage.interpolation as inter
14
  import easyocr
15
- import datasets
16
- from datasets import load_dataset, Image
17
- from PIL import Image
18
  from paddleocr import PaddleOCR
19
- from save_data import flag
20
-
 
 
 
 
 
 
 
 
21
  """
22
  Paddle OCR
23
  """
24
  def ocr_with_paddle(img):
25
  finaltext = ''
26
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
27
- # img_path = 'exp.jpeg'
28
  result = ocr.ocr(img)
29
 
30
  for i in range(len(result[0])):
31
  text = result[0][i][1][0]
32
- finaltext += ' '+ text
33
  return finaltext
34
 
35
  """
@@ -37,84 +33,62 @@ Keras OCR
37
  """
38
  def ocr_with_keras(img):
39
  output_text = ''
40
- pipeline=keras_ocr.pipeline.Pipeline()
41
- images=[keras_ocr.tools.read(img)]
42
- predictions=pipeline.recognize(images)
43
- first=predictions[0]
44
- for text,box in first:
45
- output_text += ' '+ text
46
  return output_text
47
 
48
  """
49
- easy OCR
50
  """
51
- # gray scale image
52
- def get_grayscale(image):
53
- return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
54
-
55
- # Thresholding or Binarization
56
- def thresholding(src):
57
- return cv2.threshold(src,127,255, cv2.THRESH_TOZERO)[1]
58
  def ocr_with_easy(img):
59
- gray_scale_image=get_grayscale(img)
60
- thresholding(gray_scale_image)
61
- cv2.imwrite('image.png',gray_scale_image)
62
- reader = easyocr.Reader(['th','en'])
63
- bounds = reader.readtext('image.png',paragraph="False",detail = 0)
64
- bounds = ''.join(bounds)
65
- return bounds
66
-
67
  """
68
- Generate OCR
69
  """
70
- def generate_ocr(Method,img):
71
-
 
 
 
72
  text_output = ''
73
- if (img).any():
74
- add_csv = []
75
- image_id = 1
76
- print("Method___________________",Method)
77
- if Method == 'EasyOCR':
78
- text_output = ocr_with_easy(img)
79
- if Method == 'KerasOCR':
80
- text_output = ocr_with_keras(img)
81
- if Method == 'PaddleOCR':
82
- text_output = ocr_with_paddle(img)
83
-
84
- try:
85
- flag(Method,text_output,img)
86
- except Exception as e:
87
- print(e)
88
- return text_output
89
- else:
90
- raise gr.Error("Please upload an image!!!!")
91
 
92
- # except Exception as e:
93
- # print("Error in ocr generation ==>",e)
94
- # text_output = "Something went wrong"
95
- # return text_output
96
 
 
 
 
 
97
 
98
  """
99
- Create user interface for OCR demo
100
  """
101
-
102
- # image = gr.Image(shape=(300, 300))
103
  image = gr.Image()
104
- method = gr.Radio(["PaddleOCR","EasyOCR", "KerasOCR"],value="PaddleOCR")
105
- output = gr.Textbox(label="Output")
 
106
 
107
  demo = gr.Interface(
108
- generate_ocr,
109
- [method,image],
110
- output,
111
- title="Optical Character Recognition",
112
- css=".gradio-container {background-color: lightgray} #radio_div {background-color: #FFD8B4; font-size: 40px;}",
113
- article = """<p style='text-align: center;'>Feel free to give us your thoughts on this demo and please contact us at
114
- <a href="mailto:[email protected]" target="_blank">[email protected]</a>
115
- <p style='text-align: center;'>Developed by: <a href="https://www.pragnakalp.com" target="_blank">Pragnakalp Techlabs</a></p>"""
116
-
117
-
118
  )
119
- # demo.launch(enable_queue = False)
120
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
  import keras_ocr
 
5
  import cv2
 
 
 
 
 
 
 
 
6
  import easyocr
 
 
 
7
  from paddleocr import PaddleOCR
8
+ import numpy as np
9
+
10
+ # Load tokenizer
11
+ tokenizer = DistilBertTokenizer.from_pretrained("./distilbert_spam_model")
12
+
13
+ # Load model
14
+ model = DistilBertForSequenceClassification.from_pretrained("./distilbert_spam_model")
15
+ model.load_state_dict(torch.load("./distilbert_spam_model/model.pth", map_location=torch.device('cpu')))
16
+ model.eval()
17
+
18
  """
19
  Paddle OCR
20
  """
21
  def ocr_with_paddle(img):
22
  finaltext = ''
23
  ocr = PaddleOCR(lang='en', use_angle_cls=True)
 
24
  result = ocr.ocr(img)
25
 
26
  for i in range(len(result[0])):
27
  text = result[0][i][1][0]
28
+ finaltext += ' ' + text
29
  return finaltext
30
 
31
  """
 
33
  """
34
  def ocr_with_keras(img):
35
  output_text = ''
36
+ pipeline = keras_ocr.pipeline.Pipeline()
37
+ images = [keras_ocr.tools.read(img)]
38
+ predictions = pipeline.recognize(images)
39
+
40
+ for text, _ in predictions[0]:
41
+ output_text += ' ' + text
42
  return output_text
43
 
44
  """
45
+ Easy OCR
46
  """
 
 
 
 
 
 
 
47
  def ocr_with_easy(img):
48
+ reader = easyocr.Reader(['en'])
49
+ bounds = reader.readtext(img, paragraph=True, detail=0)
50
+ return ' '.join(bounds)
51
+
 
 
 
 
52
  """
53
+ Generate OCR and classify spam
54
  """
55
+ def generate_ocr_and_classify(Method, img):
56
+ if img is None:
57
+ raise gr.Error("Please upload an image!")
58
+
59
+ # Perform OCR
60
  text_output = ''
61
+ if Method == 'EasyOCR':
62
+ text_output = ocr_with_easy(img)
63
+ elif Method == 'KerasOCR':
64
+ text_output = ocr_with_keras(img)
65
+ elif Method == 'PaddleOCR':
66
+ text_output = ocr_with_paddle(img)
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # Classify extracted text
69
+ inputs = tokenizer(text_output, return_tensors="pt", truncation=True, padding=True)
70
+ with torch.no_grad():
71
+ outputs = model(**inputs)
72
 
73
+ prediction = torch.argmax(outputs.logits, dim=1).item()
74
+ classification = "Spam" if prediction == 1 else "Not Spam"
75
+
76
+ return text_output, classification
77
 
78
  """
79
+ Create user interface
80
  """
 
 
81
  image = gr.Image()
82
+ method = gr.Radio(["PaddleOCR", "EasyOCR", "KerasOCR"], value="PaddleOCR")
83
+ output_text = gr.Textbox(label="Extracted Text")
84
+ output_label = gr.Label(label="Classification")
85
 
86
  demo = gr.Interface(
87
+ generate_ocr_and_classify,
88
+ [method, image],
89
+ [output_text, output_label],
90
+ title="OCR & Spam Classification",
91
+ description="Upload an image with text, extract the text using OCR, and classify whether it is spam or not.",
 
 
 
 
 
92
  )
93
+
94
  demo.launch()