import os
import re
import traceback

import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

# ─── Model loading ────────────────────────────────────────────────────
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"

# Task start token that primes the decoder for the CORD-v2 checkpoint.
# An empty prompt gives the decoder nothing to start from, and the
# post-processing in ocr_donut() strips exactly one leading tag, which is
# this token.
TASK_PROMPT = "<s_cord-v2>"

processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


# ─── OCR function ─────────────────────────────────────────────────────
def ocr_donut(image):
    """Run the Donut model on an uploaded image and return the parsed output.

    Args:
        image: The value delivered by the ``gr.Image(type="pil")`` input;
            ``None`` when no image was uploaded. It must support
            ``.convert("RGB")``.

    Returns:
        ``{"result": ...}`` holding the output of ``processor.token2json``
        on success, or ``{"error": message}`` when no image was given or
        any step raised.
    """
    if image is None:
        return {"error": "No image provided."}

    try:
        decoder_input_ids = processor.tokenizer(
            TASK_PROMPT, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)
        pixel_values = processor(
            image.convert("RGB"), return_tensors="pt"
        ).pixel_values.to(device)

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.config.decoder.max_position_embeddings,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            # Forbid the unknown token so it never appears in the output.
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        seq = processor.batch_decode(outputs.sequences)[0]
        seq = seq.replace(processor.tokenizer.eos_token, "")
        seq = seq.replace(processor.tokenizer.pad_token, "")
        # Drop only the first tag: the task start token echoed by the decoder.
        seq = re.sub(r"<.*?>", "", seq, count=1).strip()
        return {"result": processor.token2json(seq)}
    except Exception:
        # UI boundary: report the failure in the JSON panel instead of
        # letting the request crash.
        # NOTE(review): the full traceback is sent to the client; consider a
        # shorter message if this app is exposed publicly.
        tb = traceback.format_exc()
        print(tb)
        return {"error": tb}


# ─── CSS styling ──────────────────────────────────────────────────────
custom_css = """
body {
    background: #f0f2f5;
    font-family: 'Segoe UI', Tahoma, sans-serif;
}
.gradio-container {
    max-width: 900px;
    margin: 40px auto;
    padding: 20px;
}
.header {
    text-align: center;
    margin-bottom: 30px;
}
.header h1 {
    font-size: 2.8rem;
    color: #333;
    margin: 0;
}
.header p {
    color: #666;
    margin-top: 8px;
}
.input-box, .output-box {
    background: #fff;
    border-radius: 8px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
    padding: 20px;
}
.input-box {
    margin-right: 10px;
}
.output-box {
    margin-left: 10px;
}
.gr-button {
    background: #5a8dee !important;
    color: #fff !important;
    border-radius: 6px !important;
    padding: 10px 20px !important;
    font-size: 1rem !important;
    margin-top: 10px !important;
    transition: background 0.2s ease;
}
.gr-button:hover {
    background: #3f6fcc !important;
}
.footer {
    text-align: center;
    margin-top: 30px;
    color: #999;
    font-size: 0.85rem;
}
"""

# ─── Blocks layout ────────────────────────────────────────────────────
with gr.Blocks(css=custom_css, title="Donut OCR App") as demo:
    # Header: wrapped in .header / h1 / p so the matching CSS rules apply.
    gr.HTML(
        """
        <div class="header">
            <h1>📄 Donut OCR</h1>
            <p>Industrial AI Engineering Week 8 Assignment</p>
        </div>
        """
    )

    # Input / output area
    with gr.Row():
        with gr.Column(elem_classes="input-box"):
            image_input = gr.Image(type="pil", label="Upload Document Image")
            run_btn = gr.Button("Run OCR", elem_id="run-btn")
        with gr.Column(elem_classes="output-box"):
            result_box = gr.JSON(label="Output")

    # Wire the button click to the OCR function
    run_btn.click(fn=ocr_donut, inputs=image_input, outputs=result_box)

    # Footer
    # NOTE(review): the footer markup is empty although custom_css defines
    # .footer; fill in the intended content.
    gr.HTML(""" """)

# Launch (Spaces): bind to all interfaces; the port comes from $PORT,
# defaulting to 7860.
demo.launch(
    server_name="0.0.0.0",
    server_port=int(os.environ.get("PORT", 7860)),
    debug=True,
)