gagan3012 committed
Commit
3eb719d
1 Parent(s): af367a2

Update app.py

Files changed (1)
  1. app.py +21 -3
app.py CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
 from streamlit_cropper import st_cropper
 from PIL import Image
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor, NougatProcessor
 import torch
 import re
 import pytesseract
@@ -58,6 +58,24 @@ def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2
     sequence = re.sub(r"<.*?>", "", sequence).strip()
     return sequence
 
+def predict_nougat(img, model_name="facebook/nougat-small"):
+    processor = NougatProcessor.from_pretrained(model_name)
+    model = VisionEncoderDecoderModel.from_pretrained(model_name)
+    image = img.convert("RGB")
+    pixel_values = processor(image, return_tensors="pt").pixel_values
+
+    # autoregressively generate the transcription (up to 1500 new tokens)
+    outputs = model.generate(
+        pixel_values.to(device),
+        min_length=1,
+        max_new_tokens=1500,
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+    )
+
+    page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+    page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
+    return page_sequence
+
 def predict_tesseract(img):
     text = pytesseract.image_to_string(Image.open(img))
     return text
@@ -96,7 +114,7 @@ Lng = st.sidebar.selectbox(label="Language", options=[
 
 Models = {
     "Arabic": "Qalam",
-    "English": "Donut",
+    "English": "Nougat",
     "French": "Tesseract",
     "Korean": "Donut",
     "Chinese": "Donut"
@@ -138,7 +156,7 @@ if img_file:
         text_file = BytesIO(ocr_text.encode())
         st.download_button('Download Text', text_file, file_name='ocr_text.txt')
     elif Lng == "English":
-        ocr_text = predict_english(cropped_img)
+        ocr_text = predict_nougat(cropped_img)
         st.subheader(f"OCR Results for {Lng}")
         st.write(ocr_text)
         text_file = BytesIO(ocr_text.encode())
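
For reference, a minimal standalone sketch of the Nougat path added in this commit, runnable outside Streamlit. The checkpoint name and generation settings match the diff; the `device` setup and the example image path are assumptions (app.py defines `device` elsewhere):

import torch
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel

# Assumption: app.py defines `device` elsewhere; recreate it for a standalone run.
device = "cuda" if torch.cuda.is_available() else "cpu"

def predict_nougat(img, model_name="facebook/nougat-small"):
    processor = NougatProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
    pixel_values = processor(img.convert("RGB"), return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=1500,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return processor.post_process_generation(text, fix_markdown=False)

if __name__ == "__main__":
    # "sample_page.png" is a hypothetical input image, not part of the repo.
    print(predict_nougat(Image.open("sample_page.png")))

Unlike the diffed function, this sketch also moves the model to `device`, so pixel values and weights end up on the same device when a GPU is available; the committed code only moves the pixel values.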