File size: 4,285 Bytes
7a6934f 58a66e8 7a6934f 58a66e8 7a6934f 58a66e8 7a6934f 58a66e8 7a6934f 58a66e8 7a6934f 58a66e8 07e5dc5 58a66e8 07e5dc5 ce9c63c 07e5dc5 7a6934f 58a66e8 7a6934f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import os
import re
import torch
import traceback
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
# ─── Model loading ──────────────────────────────────────────────────────
# Hugging Face Hub id of the Donut checkpoint fine-tuned on CORD-v2.
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loaded once at import time so every request reuses the same processor and
# the same model instance, already moved to the selected device.
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)
# ─── OCR function ───────────────────────────────────────────────────────
def ocr_donut(image):
    """Run the Donut model on one document image and return the parsed output.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
            It must provide ``convert("RGB")``; the UI passes a PIL image.

    Returns:
        ``{"result": <parsed structure>}`` on success, or
        ``{"error": <message>}`` when no image was supplied or processing
        failed. Failures are reported in the returned dict, not raised.
    """
    if image is None:
        return {"error": "No image provided."}

    try:
        # Convert first so an unusable upload fails before any model work.
        rgb_image = image.convert("RGB")
        pixel_values = processor(rgb_image, return_tensors="pt").pixel_values.to(device)

        # The decoder is primed with the task start token for this checkpoint.
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        tokenizer = processor.tokenizer
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            # Cap generation at the decoder's maximum position count.
            max_length=model.config.decoder.max_position_embeddings,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            # Never emit the unknown token.
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        seq = processor.batch_decode(outputs.sequences)[0]
        seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        # count=1: strip only the first tag, i.e. the leading task start token.
        seq = re.sub(r"<.*?>", "", seq, count=1).strip()
        return {"result": processor.token2json(seq)}
    except Exception as exc:
        # The full traceback stays in the server log. The client only gets the
        # exception type and message, so file paths and library internals are
        # not exposed through the public UI.
        print(traceback.format_exc())
        return {"error": f"{type(exc).__name__}: {exc}"}
# ─── CSS styling ────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...): page background and font, a centred
# container, header/footer text, card styling for the columns tagged with the
# "input-box" / "output-box" element classes, and button colours.
# NOTE(review): the .gr-button rules assume Gradio renders buttons with that
# class name; confirm against the installed Gradio version.
custom_css = """
body { background: #f0f2f5; font-family: 'Segoe UI', Tahoma, sans-serif; }
.gradio-container { max-width: 900px; margin: 40px auto; padding: 20px; }
.header { text-align: center; margin-bottom: 30px; }
.header h1 { font-size: 2.8rem; color: #333; margin: 0; }
.header p { color: #666; margin-top: 8px; }
.input-box, .output-box {
background: #fff;
border-radius: 8px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
padding: 20px;
}
.input-box { margin-right: 10px; }
.output-box { margin-left: 10px; }
.gr-button {
background: #5a8dee !important;
color: #fff !important;
border-radius: 6px !important;
padding: 10px 20px !important;
font-size: 1rem !important;
margin-top: 10px !important;
transition: background 0.2s ease;
}
.gr-button:hover { background: #3f6fcc !important; }
.footer {
text-align: center;
margin-top: 30px;
color: #999;
font-size: 0.85rem;
}
"""
# ─── Blocks layout ──────────────────────────────────────────────────────
# Page layout: a header, a two-column row (upload + button on one side, JSON
# output on the other), and a footer. The button runs ocr_donut on the
# uploaded image and shows the returned dict in the JSON panel.
with gr.Blocks(css=custom_css, title="Donut OCR App") as demo:
    # Header
    gr.HTML(
        "\n"
        '<div class="header">\n'
        "<h1>π Donut OCR</h1>\n"
        "<p>Industrial AI Engineering Week 8 Assignment</p>\n"
        "</div>\n"
    )

    # Input / output area
    with gr.Row():
        with gr.Column(elem_classes="input-box"):
            image_input = gr.Image(type="pil", label="Upload Document Image")
            run_btn = gr.Button("Run OCR", elem_id="run-btn")
        with gr.Column(elem_classes="output-box"):
            result_box = gr.JSON(label="Output")

    # Button click wiring
    run_btn.click(fn=ocr_donut, inputs=image_input, outputs=result_box)

    # Footer
    gr.HTML(
        "\n"
        '<div class="footer">\n'
        "<p>Powered by Naver Clova Donut</p>\n"
        "</div>\n"
    )
# Launch (Spaces)
# Start the server on all interfaces. The port comes from the PORT environment
# variable when it is set, and falls back to 7860 otherwise.
_port = int(os.environ.get("PORT", 7860))
demo.launch(server_name="0.0.0.0", server_port=_port, debug=True)
|