JJS0321 commited on
Commit
58a66e8
Β·
1 Parent(s): c273705

make app more beautiful

Browse files
Files changed (1) hide show
  1. app.py +69 -29
app.py CHANGED
@@ -1,37 +1,29 @@
1
  import os
2
  import re
3
- import gradio as gr
4
- from transformers import DonutProcessor, VisionEncoderDecoderModel
5
  import torch
6
  import traceback
 
 
7
 
8
- # 1) Load pretrained Donut model and processor
9
  MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"
10
  processor = DonutProcessor.from_pretrained(MODEL_NAME)
11
  model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
12
 
13
- # 2) Set device and move model
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model.to(device)
16
 
17
- # 3) Inference function with debugging
18
  def ocr_donut(image):
19
  try:
20
  if image is None:
21
  return {"error": "No image provided."}
22
-
23
- # Prepare prompt and inputs
24
  task_prompt = "<s_cord-v2>"
25
  decoder_input_ids = processor.tokenizer(
26
- task_prompt,
27
- add_special_tokens=False,
28
- return_tensors="pt"
29
  ).input_ids.to(device)
30
-
31
- # Convert to tensor
32
  pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)
33
 
34
- # Generate outputs
35
  outputs = model.generate(
36
  pixel_values,
37
  decoder_input_ids=decoder_input_ids,
@@ -43,29 +35,77 @@ def ocr_donut(image):
43
  return_dict_in_generate=True,
44
  )
45
 
46
- # Decode and clean up
47
- sequence = processor.batch_decode(outputs.sequences)[0]
48
- sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
49
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
50
- json_output = processor.token2json(sequence)
51
-
52
- return {"result": json_output}
53
 
54
  except Exception:
55
  tb = traceback.format_exc()
56
  print(tb)
57
  return {"error": tb}
58
 
59
- # 4) Build Gradio interface
60
- demo = gr.Interface(
61
- fn=ocr_donut,
62
- inputs=gr.Image(type="pil", label="Upload Document Image"),
63
- outputs=gr.JSON(label="Output"),
64
- title="Donut OCR Gradio App",
65
- description="Upload a document image and get structured JSON output. Errors will be shown for debugging."
66
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # 5) Launch for Spaces
69
  demo.launch(
70
  server_name="0.0.0.0",
71
  server_port=int(os.environ.get("PORT", 7860)),
 
1
  import os
2
  import re
 
 
3
  import torch
4
  import traceback
5
+ import gradio as gr
6
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
7
 
8
+ # ─── λͺ¨λΈ λ‘œλ”© ─────────────────────────────────────────────────────────
9
  MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"
10
  processor = DonutProcessor.from_pretrained(MODEL_NAME)
11
  model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
15
 
16
+ # ─── OCR ν•¨μˆ˜ ──────────────────────────────────────────────────────────
17
  def ocr_donut(image):
18
  try:
19
  if image is None:
20
  return {"error": "No image provided."}
 
 
21
  task_prompt = "<s_cord-v2>"
22
  decoder_input_ids = processor.tokenizer(
23
+ task_prompt, add_special_tokens=False, return_tensors="pt"
 
 
24
  ).input_ids.to(device)
 
 
25
  pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)
26
 
 
27
  outputs = model.generate(
28
  pixel_values,
29
  decoder_input_ids=decoder_input_ids,
 
35
  return_dict_in_generate=True,
36
  )
37
 
38
+ seq = processor.batch_decode(outputs.sequences)[0]
39
+ seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
40
+ seq = re.sub(r"<.*?>", "", seq, count=1).strip()
41
+ return {"result": processor.token2json(seq)}
 
 
 
42
 
43
  except Exception:
44
  tb = traceback.format_exc()
45
  print(tb)
46
  return {"error": tb}
47
 
48
+ # ─── CSS μŠ€νƒ€μΌλ§ ────────────────────────────────────────────────────
49
+ custom_css = """
50
+ body { background: #f0f2f5; font-family: 'Segoe UI', Tahoma, sans-serif; }
51
+ .gradio-container { max-width: 900px; margin: 40px auto; padding: 20px; }
52
+ .header { text-align: center; margin-bottom: 30px; }
53
+ .header h1 { font-size: 2.8rem; color: #333; margin: 0; }
54
+ .header p { color: #666; margin-top: 8px; }
55
+
56
+ .input-box, .output-box {
57
+ background: #fff;
58
+ border-radius: 8px;
59
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1);
60
+ padding: 20px;
61
+ }
62
+ .input-box { margin-right: 10px; }
63
+ .output-box { margin-left: 10px; }
64
+
65
+ .gr-button {
66
+ background: #5a8dee !important;
67
+ color: #fff !important;
68
+ border-radius: 6px !important;
69
+ padding: 10px 20px !important;
70
+ font-size: 1rem !important;
71
+ margin-top: 10px !important;
72
+ transition: background 0.2s ease;
73
+ }
74
+ .gr-button:hover { background: #3f6fcc !important; }
75
+
76
+ .footer {
77
+ text-align: center;
78
+ margin-top: 30px;
79
+ color: #999;
80
+ font-size: 0.85rem;
81
+ }
82
+ """
83
+
84
+ # ─── Blocks λ ˆμ΄μ•„μ›ƒ ──────────────────────────────────────────────────
85
+ with gr.Blocks(css=custom_css, title="Donut OCR App") as demo:
86
+ # 헀더
87
+ with gr.HTML(elem_classes="header"):
88
+ gr.HTML("""
89
+ <h1>πŸ“„ Donut OCR</h1>
90
+ <p>Industrial AI Engineering Week 8 Assignment</p>
91
+ """)
92
+
93
+ # μž…λ ₯/좜λ ₯ μ˜μ—­
94
+ with gr.Row():
95
+ with gr.Column(elem_classes="input-box"):
96
+ image_input = gr.Image(type="pil", label="Upload Document Image")
97
+ run_btn = gr.Button("Run OCR", elem_id="run-btn")
98
+ with gr.Column(elem_classes="output-box"):
99
+ result_box = gr.JSON(label="Output")
100
+
101
+ # λ²„νŠΌ 클릭 μ—°κ²°
102
+ run_btn.click(fn=ocr_donut, inputs=image_input, outputs=result_box)
103
+
104
+ # ν‘Έν„°
105
+ with gr.HTML(elem_classes="footer"):
106
+ gr.HTML("<p>Powered by Naver Clova Donut β€’ Built with πŸ’œ by You</p>")
107
 
108
+ # Spaces μ‹€ν–‰
109
  demo.launch(
110
  server_name="0.0.0.0",
111
  server_port=int(os.environ.get("PORT", 7860)),