JJS0321's picture
make app more beautiful
ce9c63c
raw
history blame contribute delete
4.29 kB
import os
import re
import torch
import traceback
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
# ─── Model loading ─────────────────────────────────────────────────────
# Hub identifier of the pretrained checkpoint used for both the processor and
# the model below.
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"

# Use the GPU when torch reports CUDA as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loaded once at module import; ocr_donut() reuses these module-level objects
# on every call.
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
model.to(device)
# ─── OCR function ──────────────────────────────────────────────────────
def ocr_donut(image):
    """Run the Donut model on one image and return the parsed output.

    Args:
        image: Image object exposing ``convert("RGB")`` (the Gradio input is
            configured with ``type="pil"``), or ``None`` when nothing was
            uploaded.

    Returns:
        dict: ``{"result": ...}`` with the value produced by
        ``processor.token2json`` on the cleaned decoded sequence, or
        ``{"error": ...}`` holding either a fixed message (no image) or the
        formatted traceback of whatever exception was raised.
    """
    # Nothing to process: report it without touching the model.
    if image is None:
        return {"error": "No image provided."}

    try:
        tokenizer = processor.tokenizer

        # Encode the task prompt that seeds the decoder.
        task_prompt = "<s_cord-v2>"
        prompt_encoding = tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )
        decoder_input_ids = prompt_encoding.input_ids.to(device)

        # Preprocess the image into model-ready pixel values on the same device.
        rgb_image = image.convert("RGB")
        pixel_values = processor(rgb_image, return_tensors="pt").pixel_values.to(device)

        generated = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.config.decoder.max_position_embeddings,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            # Forbid the unknown token from being generated.
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        # Decode the first (only) sequence and drop EOS / padding markers.
        decoded = processor.batch_decode(generated.sequences)[0]
        for special_token in (tokenizer.eos_token, tokenizer.pad_token):
            decoded = decoded.replace(special_token, "")

        # Remove only the first <...> tag (presumably the leading task-start
        # token), keeping every later tag for token2json to interpret.
        decoded = re.sub(r"<.*?>", "", decoded, count=1).strip()

        return {"result": processor.token2json(decoded)}
    except Exception:
        # Any failure is logged to stdout and returned to the caller as text
        # instead of propagating.
        error_text = traceback.format_exc()
        print(error_text)
        return {"error": error_text}
# ─── CSS styling ───────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...). The .header / .footer rules target
# the <div> wrappers emitted through gr.HTML, and .input-box / .output-box
# match the elem_classes given to the two layout columns.
# NOTE(review): .gradio-container and .gr-button are not defined in this file;
# presumably they are class names generated by Gradio — confirm they exist in
# the installed Gradio version.
custom_css = """
body { background: #f0f2f5; font-family: 'Segoe UI', Tahoma, sans-serif; }
.gradio-container { max-width: 900px; margin: 40px auto; padding: 20px; }
.header { text-align: center; margin-bottom: 30px; }
.header h1 { font-size: 2.8rem; color: #333; margin: 0; }
.header p { color: #666; margin-top: 8px; }
.input-box, .output-box {
background: #fff;
border-radius: 8px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
padding: 20px;
}
.input-box { margin-right: 10px; }
.output-box { margin-left: 10px; }
.gr-button {
background: #5a8dee !important;
color: #fff !important;
border-radius: 6px !important;
padding: 10px 20px !important;
font-size: 1rem !important;
margin-top: 10px !important;
transition: background 0.2s ease;
}
.gr-button:hover { background: #3f6fcc !important; }
.footer {
text-align: center;
margin-top: 30px;
color: #999;
font-size: 0.85rem;
}
"""
# ─── Blocks layout ─────────────────────────────────────────────────────
# Build the page: a header, a two-column row (image upload + button on the
# left, JSON result on the right), the click wiring, and a footer.
with gr.Blocks(css=custom_css, title="Donut OCR App") as demo:
    # Header (styled by the .header rules in custom_css).
    # The heading emoji was previously mis-encoded and rendered as garbage
    # characters; it is restored to the intended U+1F4C4 "page facing up".
    gr.HTML(
        """
<div class="header">
<h1>📄 Donut OCR</h1>
<p>Industrial AI Engineering Week 8 Assignment</p>
</div>
"""
    )
    # Input / output area
    with gr.Row():
        with gr.Column(elem_classes="input-box"):
            # type="pil" so ocr_donut receives an object it can .convert("RGB").
            image_input = gr.Image(type="pil", label="Upload Document Image")
            run_btn = gr.Button("Run OCR", elem_id="run-btn")
        with gr.Column(elem_classes="output-box"):
            # Shows the dict returned by ocr_donut ("result" or "error").
            result_box = gr.JSON(label="Output")
    # Wire the button click to the OCR function.
    run_btn.click(fn=ocr_donut, inputs=image_input, outputs=result_box)
    # Footer (styled by the .footer rules in custom_css).
    gr.HTML(
        """
<div class="footer">
<p>Powered by Naver Clova Donut</p>
</div>
"""
    )
# Launch on Spaces
# Start the web server for the app.
demo.launch(
    debug=True,
    # "0.0.0.0" binds to all network interfaces rather than loopback only.
    server_name="0.0.0.0",
    # Port is taken from the PORT environment variable, defaulting to 7860.
    server_port=int(os.getenv("PORT", 7860)),
)