make app more beautiful
Browse files
app.py
CHANGED
@@ -1,37 +1,29 @@
|
|
1 |
import os
|
2 |
import re
|
3 |
-
import gradio as gr
|
4 |
-
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
5 |
import torch
|
6 |
import traceback
|
|
|
|
|
7 |
|
8 |
-
#
|
9 |
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"
|
10 |
processor = DonutProcessor.from_pretrained(MODEL_NAME)
|
11 |
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
|
12 |
|
13 |
-
# 2) Set device and move model
|
14 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
model.to(device)
|
16 |
|
17 |
-
#
|
18 |
def ocr_donut(image):
|
19 |
try:
|
20 |
if image is None:
|
21 |
return {"error": "No image provided."}
|
22 |
-
|
23 |
-
# Prepare prompt and inputs
|
24 |
task_prompt = "<s_cord-v2>"
|
25 |
decoder_input_ids = processor.tokenizer(
|
26 |
-
task_prompt,
|
27 |
-
add_special_tokens=False,
|
28 |
-
return_tensors="pt"
|
29 |
).input_ids.to(device)
|
30 |
-
|
31 |
-
# Convert to tensor
|
32 |
pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)
|
33 |
|
34 |
-
# Generate outputs
|
35 |
outputs = model.generate(
|
36 |
pixel_values,
|
37 |
decoder_input_ids=decoder_input_ids,
|
@@ -43,29 +35,77 @@ def ocr_donut(image):
|
|
43 |
return_dict_in_generate=True,
|
44 |
)
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
json_output = processor.token2json(sequence)
|
51 |
-
|
52 |
-
return {"result": json_output}
|
53 |
|
54 |
except Exception:
|
55 |
tb = traceback.format_exc()
|
56 |
print(tb)
|
57 |
return {"error": tb}
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
#
|
69 |
demo.launch(
|
70 |
server_name="0.0.0.0",
|
71 |
server_port=int(os.environ.get("PORT", 7860)),
|
|
|
1 |
import os
|
2 |
import re
|
|
|
|
|
3 |
import torch
|
4 |
import traceback
|
5 |
+
import gradio as gr
|
6 |
+
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
7 |
|
8 |
+
# βββ λͺ¨λΈ λ‘λ© βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
9 |
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"
|
10 |
processor = DonutProcessor.from_pretrained(MODEL_NAME)
|
11 |
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
|
12 |
|
|
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
model.to(device)
|
15 |
|
16 |
+
# βββ OCR ν¨μ ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
17 |
def ocr_donut(image):
|
18 |
try:
|
19 |
if image is None:
|
20 |
return {"error": "No image provided."}
|
|
|
|
|
21 |
task_prompt = "<s_cord-v2>"
|
22 |
decoder_input_ids = processor.tokenizer(
|
23 |
+
task_prompt, add_special_tokens=False, return_tensors="pt"
|
|
|
|
|
24 |
).input_ids.to(device)
|
|
|
|
|
25 |
pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)
|
26 |
|
|
|
27 |
outputs = model.generate(
|
28 |
pixel_values,
|
29 |
decoder_input_ids=decoder_input_ids,
|
|
|
35 |
return_dict_in_generate=True,
|
36 |
)
|
37 |
|
38 |
+
seq = processor.batch_decode(outputs.sequences)[0]
|
39 |
+
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
40 |
+
seq = re.sub(r"<.*?>", "", seq, count=1).strip()
|
41 |
+
return {"result": processor.token2json(seq)}
|
|
|
|
|
|
|
42 |
|
43 |
except Exception:
|
44 |
tb = traceback.format_exc()
|
45 |
print(tb)
|
46 |
return {"error": tb}
|
47 |
|
48 |
+
# βββ CSS μ€νμΌλ§ ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
49 |
+
custom_css = """
|
50 |
+
body { background: #f0f2f5; font-family: 'Segoe UI', Tahoma, sans-serif; }
|
51 |
+
.gradio-container { max-width: 900px; margin: 40px auto; padding: 20px; }
|
52 |
+
.header { text-align: center; margin-bottom: 30px; }
|
53 |
+
.header h1 { font-size: 2.8rem; color: #333; margin: 0; }
|
54 |
+
.header p { color: #666; margin-top: 8px; }
|
55 |
+
|
56 |
+
.input-box, .output-box {
|
57 |
+
background: #fff;
|
58 |
+
border-radius: 8px;
|
59 |
+
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
60 |
+
padding: 20px;
|
61 |
+
}
|
62 |
+
.input-box { margin-right: 10px; }
|
63 |
+
.output-box { margin-left: 10px; }
|
64 |
+
|
65 |
+
.gr-button {
|
66 |
+
background: #5a8dee !important;
|
67 |
+
color: #fff !important;
|
68 |
+
border-radius: 6px !important;
|
69 |
+
padding: 10px 20px !important;
|
70 |
+
font-size: 1rem !important;
|
71 |
+
margin-top: 10px !important;
|
72 |
+
transition: background 0.2s ease;
|
73 |
+
}
|
74 |
+
.gr-button:hover { background: #3f6fcc !important; }
|
75 |
+
|
76 |
+
.footer {
|
77 |
+
text-align: center;
|
78 |
+
margin-top: 30px;
|
79 |
+
color: #999;
|
80 |
+
font-size: 0.85rem;
|
81 |
+
}
|
82 |
+
"""
|
83 |
+
|
84 |
+
# βββ Blocks λ μ΄μμ ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
85 |
+
with gr.Blocks(css=custom_css, title="Donut OCR App") as demo:
|
86 |
+
# ν€λ
|
87 |
+
with gr.HTML(elem_classes="header"):
|
88 |
+
gr.HTML("""
|
89 |
+
<h1>π Donut OCR</h1>
|
90 |
+
<p>Industrial AI Engineering Week 8 Assignment</p>
|
91 |
+
""")
|
92 |
+
|
93 |
+
# μ
λ ₯/μΆλ ₯ μμ
|
94 |
+
with gr.Row():
|
95 |
+
with gr.Column(elem_classes="input-box"):
|
96 |
+
image_input = gr.Image(type="pil", label="Upload Document Image")
|
97 |
+
run_btn = gr.Button("Run OCR", elem_id="run-btn")
|
98 |
+
with gr.Column(elem_classes="output-box"):
|
99 |
+
result_box = gr.JSON(label="Output")
|
100 |
+
|
101 |
+
# λ²νΌ ν΄λ¦ μ°κ²°
|
102 |
+
run_btn.click(fn=ocr_donut, inputs=image_input, outputs=result_box)
|
103 |
+
|
104 |
+
# νΈν°
|
105 |
+
with gr.HTML(elem_classes="footer"):
|
106 |
+
gr.HTML("<p>Powered by Naver Clova Donut β’ Built with π by You</p>")
|
107 |
|
108 |
+
# Spaces μ€ν
|
109 |
demo.launch(
|
110 |
server_name="0.0.0.0",
|
111 |
server_port=int(os.environ.get("PORT", 7860)),
|