Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from streamlit_cropper import st_cropper
|
3 |
from PIL import Image
|
4 |
-
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor
|
5 |
import torch
|
6 |
import re
|
7 |
import pytesseract
|
@@ -58,6 +58,24 @@ def predict_english(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2
|
|
58 |
sequence = re.sub(r"<.*?>", "", sequence).strip()
|
59 |
return sequence
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def predict_tesseract(img):
|
62 |
text = pytesseract.image_to_string(Image.open(img))
|
63 |
return text
|
@@ -96,7 +114,7 @@ Lng = st.sidebar.selectbox(label="Language", options=[
|
|
96 |
|
97 |
Models = {
|
98 |
"Arabic": "Qalam",
|
99 |
-
"English": "
|
100 |
"French": "Tesseract",
|
101 |
"Korean": "Donut",
|
102 |
"Chinese": "Donut"
|
@@ -138,7 +156,7 @@ if img_file:
|
|
138 |
text_file = BytesIO(ocr_text.encode())
|
139 |
st.download_button('Download Text', text_file, file_name='ocr_text.txt')
|
140 |
elif Lng == "English":
|
141 |
-
ocr_text =
|
142 |
st.subheader(f"OCR Results for {Lng}")
|
143 |
st.write(ocr_text)
|
144 |
text_file = BytesIO(ocr_text.encode())
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit_cropper import st_cropper
|
3 |
from PIL import Image
|
4 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor, NougatProcessor
|
5 |
import torch
|
6 |
import re
|
7 |
import pytesseract
|
|
|
58 |
sequence = re.sub(r"<.*?>", "", sequence).strip()
|
59 |
return sequence
|
60 |
|
61 |
+
def predict_nougat(img, model_name="facebook/nougat-small"):
|
62 |
+
processor = NougatProcessor.from_pretrained(model_name)
|
63 |
+
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
64 |
+
image = img.convert("RGB")
|
65 |
+
pixel_values = processor(image, return_tensors="pt").pixel_values
|
66 |
+
|
67 |
+
# generate transcription (here we only generate 30 tokens)
|
68 |
+
outputs = model.generate(
|
69 |
+
pixel_values.to(device),
|
70 |
+
min_length=1,
|
71 |
+
max_new_tokens=1500,
|
72 |
+
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
73 |
+
)
|
74 |
+
|
75 |
+
page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
76 |
+
page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
|
77 |
+
return page_sequence
|
78 |
+
|
79 |
def predict_tesseract(img):
|
80 |
text = pytesseract.image_to_string(Image.open(img))
|
81 |
return text
|
|
|
114 |
|
115 |
Models = {
|
116 |
"Arabic": "Qalam",
|
117 |
+
"English": "Nougat",
|
118 |
"French": "Tesseract",
|
119 |
"Korean": "Donut",
|
120 |
"Chinese": "Donut"
|
|
|
156 |
text_file = BytesIO(ocr_text.encode())
|
157 |
st.download_button('Download Text', text_file, file_name='ocr_text.txt')
|
158 |
elif Lng == "English":
|
159 |
+
ocr_text = predict_nougat(cropped_img)
|
160 |
st.subheader(f"OCR Results for {Lng}")
|
161 |
st.write(ocr_text)
|
162 |
text_file = BytesIO(ocr_text.encode())
|