github-actions[bot] committed
Commit 8af6af2 · 0 Parent(s)

Sync from https://github.com/ryanlinjui/menu-text-detection

.checkpoints/.gitkeep ADDED
File without changes
.env.example ADDED
@@ -0,0 +1,3 @@
+ HUGGINGFACE_TOKEN="HUGGINGFACE_TOKEN"
+ GEMINI_API_TOKEN="GEMINI_API_TOKEN"
+ OPENAI_API_TOKEN="OPENAI_API_TOKEN"
.github/workflows/sync.yml ADDED
@@ -0,0 +1,25 @@
+ name: Sync to Hugging Face Spaces
+
+ on:
+   push:
+     branches:
+       - main
+ jobs:
+   sync:
+     name: Sync
+     runs-on: ubuntu-latest
+     steps:
+       - name: Checkout Repository
+         uses: actions/checkout@v4
+
+       - name: Remove bad files
+         run: rm -rf examples assets
+
+       - name: Sync to Hugging Face Spaces
+         uses: JacobLinCool/huggingface-sync@v1
+         with:
+           github: ${{ secrets.GITHUB_TOKEN }} # GitHub token
+           user: ryanlinjui # Hugging Face username or organization name
+           space: menu-text-detection # Hugging Face space name
+           token: ${{ secrets.HF_TOKEN }} # Hugging Face token
+           python_version: 3.11 # Python version
.gitignore ADDED
@@ -0,0 +1,24 @@
+ # mac
+ .DS_Store
+
+ # cache
+ __pycache__
+
+ # datasets
+ datasets
+
+ # papers
+ docs/papers
+
+ # uv
+ .venv
+
+ # gradio
+ .gradio
+
+ # env
+ .env
+
+ # checkpoint
+ .checkpoints/*
+ !.checkpoints/.gitkeep
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 RyanLin
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,65 @@
+ ---
+ title: menu text detection
+ emoji: 🦄
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ python_version: 3.11
+ short_description: Extract structured menu information from images into JSON...
+ tags: [ "donut","fine-tuning","image-to-text","transformer" ]
+ ---
+
+ # Menu Text Detection System
+
+ Extract structured menu information from images into JSON using a fine-tuned Donut E2E model.
+ > Based on [Donut by Clova AI (ECCV ’22)](https://github.com/clovaai/donut)
+
+ <div align="center">
+
+ <img src="./assets/demo.gif" alt="demo" width="500"/><br>
+
+ [![Gradio Space Demo](https://img.shields.io/badge/GradioSpace-Demo-important?logo=huggingface)](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)<br>
+ [![Hugging Face Models & Datasets](https://img.shields.io/badge/HuggingFace-Models_&_Datasets-important?logo=huggingface)](https://huggingface.co/collections/ryanlinjui/menu-text-detection-670ccf527626bb004bbfb39b)
+
+ </div>
+
+ ## 🚀 Features
+ ### Overview
+ Currently extracts the following information from menu images:
+
+ - **Restaurant Name**
+ - **Business Hours**
+ - **Address**
+ - **Phone Number**
+ - **Dish Information**
+   - Name
+   - Price
+
+ > For the JSON schema, see the [tools directory](./tools).
+
+ ### Supported Methods to Extract Menu Information
+ - Fine-tuned Donut model
+ - OpenAI GPT API
+ - Google Gemini API
+
+ ## 💻 Training / Fine-Tuning
+ ### Setup
+ Use [uv](https://github.com/astral-sh/uv) to set up the development environment:
+
+ ```bash
+ uv sync
+ ```
+
+ ### Training Script (Dataset Collection, Fine-Tuning)
+ Please refer to [`train.ipynb`](./train.ipynb) and use Jupyter Notebook for training:
+
+ ```bash
+ uv run jupyter-notebook
+ ```
+
+ > For VSCode users, please install the Jupyter extension, then select `.venv/bin/python` as your kernel.
+
+ ### Run Demo Locally
+ ```bash
+ uv run python app.py
+ ```
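The README above describes the extracted fields (restaurant, business hours, address, phone, dishes) and points to the schemas under `tools/`. For illustration only, a minimal sketch of what the output JSON looks like, with invented sample values and field names taken from those schemas:

```python
import json

# Illustrative output only: field names follow tools/schema_*.json,
# the restaurant and dish values are invented for this example.
sample_menu = {
    "restaurant": "Example Noodle House",
    "business_hours": "11:00-21:00",
    "address": "123 Example Rd.",
    "phone": "02-1234-5678",
    "dishes": [
        {"name": "Beef Noodle Soup", "price": 150},
        {"name": "Dumplings (10 pcs)", "price": 80}
    ]
}

print(json.dumps(sample_menu, indent=4, ensure_ascii=False))
```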
app.py ADDED
@@ -0,0 +1,151 @@
+ import os
+ import json
+
+ import numpy as np
+ import gradio as gr
+ from dotenv import load_dotenv
+
+ from menu.llm import (
+     GeminiAPI,
+     OpenAIAPI
+ )
+ from menu.donut import DonutFinetuned
+
+ load_dotenv()
+ GEMINI_API_TOKEN = os.getenv("GEMINI_API_TOKEN", "")
+ OPENAI_API_TOKEN = os.getenv("OPENAI_API_TOKEN", "")
+
+ SOURCE_CODE_GH_URL = "https://github.com/ryanlinjui/menu-text-detection"
+ BADGE_URL = "https://img.shields.io/badge/GitHub_Code-Click_Here!!-default?logo=github"
+
+ GITHUB_RAW_URL = "https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main"
+ EXAMPLE_IMAGE_LIST = [
+     f"{GITHUB_RAW_URL}/examples/menu-hd.jpg",
+     f"{GITHUB_RAW_URL}/examples/menu-vs.jpg",
+     f"{GITHUB_RAW_URL}/examples/menu-si.jpg"
+ ]
+ MODEL_LIST = [
+     "Donut Model",
+     "gemini-2.0-flash",
+     "gemini-2.5-flash-preview-04-17",
+     "gemini-2.5-pro-preview-03-25",
+     "gpt-4.1",
+     "gpt-4o",
+     "o4-mini"
+ ]
+
+ def handle(image: np.ndarray, model: str, api_token: str) -> str:
+     if image is None:
+         raise gr.Error("Please upload an image first.")
+
+     if model == MODEL_LIST[0]:
+         result = DonutFinetuned.predict(image)
+
+     elif model in MODEL_LIST[1:]:
+         if len(api_token) < 10:
+             raise gr.Error(f"Please provide a valid token for {model}.")
+         try:
+             if model in MODEL_LIST[1:4]:
+                 result = GeminiAPI.call(image, model, api_token)
+             else:
+                 result = OpenAIAPI.call(image, model, api_token)
+         except Exception as e:
+             raise gr.Error(f"Failed to process with API model {model}: {str(e)}")
+     else:
+         raise gr.Error("Invalid model selection. Please choose a valid model.")
+
+     return json.dumps(result, indent=4, ensure_ascii=False)
+
+ def UserInterface() -> gr.Interface:
+     with gr.Blocks(
+         delete_cache=(86400, 86400),
+         css="""
+         .image-panel {
+             display: flex;
+             flex-direction: column;
+             height: 600px;
+         }
+         .image-panel img {
+             object-fit: contain;
+             max-height: 600px;
+             max-width: 600px;
+             width: 100%;
+         }
+         .large-text textarea {
+             font-size: 20px !important;
+             height: 600px !important;
+             width: 100% !important;
+         }
+         """
+     ) as gradio_interface:
+         gr.HTML(f'<a href="{SOURCE_CODE_GH_URL}"><img src="{BADGE_URL}" alt="GitHub Code"/></a>')
+         gr.Markdown("# Menu Text Detection")
+
+         with gr.Row():
+             with gr.Column(scale=1, min_width=500):
+                 gr.Markdown("## 📷 Menu Image")
+                 menu_image = gr.Image(
+                     type="numpy",
+                     label="Input menu image",
+                     elem_classes="image-panel"
+                 )
+
+                 gr.Markdown("## 🤖 Model Selection")
+                 model_choice_dropdown = gr.Dropdown(
+                     choices=MODEL_LIST,
+                     value=MODEL_LIST[0],
+                     label="Select Text Detection Model"
+                 )
+
+                 api_token_textbox = gr.Textbox(
+                     label="API Token",
+                     placeholder="Enter your API token here...",
+                     type="password",
+                     visible=False
+                 )
+
+                 generate_button = gr.Button("Generate Menu Information", variant="primary")
+
+                 gr.Examples(
+                     examples=EXAMPLE_IMAGE_LIST,
+                     inputs=menu_image,
+                     label="Example Menu Images"
+                 )
+
+             with gr.Column(scale=1):
+                 gr.Markdown("## 🍽️ Menu Info")
+                 menu_json_textbox = gr.Textbox(
+                     label="Output JSON",
+                     interactive=False,
+                     text_align="left",
+                     elem_classes="large-text"
+                 )
+
+         def update_token_visibility(choice):
+             if choice in MODEL_LIST[1:]:
+                 current_token = ""
+                 if choice in MODEL_LIST[1:4]:
+                     current_token = GEMINI_API_TOKEN
+                 elif choice in MODEL_LIST[4:]:
+                     current_token = OPENAI_API_TOKEN
+                 return gr.update(visible=True, value=current_token)
+             else:
+                 return gr.update(visible=False)
+
+         model_choice_dropdown.change(
+             fn=update_token_visibility,
+             inputs=model_choice_dropdown,
+             outputs=api_token_textbox
+         )
+
+         generate_button.click(
+             fn=handle,
+             inputs=[menu_image, model_choice_dropdown, api_token_textbox],
+             outputs=menu_json_textbox
+         )
+
+     return gradio_interface
+
+ if __name__ == "__main__":
+     demo = UserInterface()
+     demo.launch()
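`handle` in `app.py` above routes a NumPy image to either the fine-tuned Donut model or one of the API models and returns the result as a JSON string. A minimal sketch of calling it outside the Gradio UI, assuming a local copy of `examples/menu-hd.jpg` (the example images live in the GitHub repo, not the Space):

```python
# Minimal sketch, not part of the commit: drive handle() directly.
import numpy as np
from PIL import Image

from app import handle, MODEL_LIST

# Assumes examples/menu-hd.jpg is available locally.
image = np.array(Image.open("examples/menu-hd.jpg").convert("RGB"))

# The Donut option needs no API token; the Gemini/GPT options would need a real key.
menu_json = handle(image, MODEL_LIST[0], api_token="")
print(menu_json)
```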
menu/donut.py ADDED
@@ -0,0 +1,225 @@
+ import json
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ from PIL import Image
+ from datasets import DatasetDict
+ from torch.utils.data import Dataset
+ from transformers import pipeline, DonutProcessor
+
+ class DonutFinetuned:
+     DEFAULT_PIPELINE = pipeline(
+         task="image-to-text",
+         model="naver-clova-ix/donut-base"
+     )
+     @classmethod
+     def predict(cls, image: np.ndarray) -> dict:
+         image = Image.fromarray(image)
+         result = cls.DEFAULT_PIPELINE(image)
+         return result
+
+ class DonutDatasets:
+     """
+     Modified from:
+     https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb
+
+     Donut PyTorch Dataset Wrapper (supports train/validation/test splits)
+     - Dynamic field names and JSON-to-token conversion
+     - Returns PyTorch Datasets with __getitem__ producing tensors
+     - Splits controlled by train_split/validation_split/test_split
+     - Only single JSON annotation supported
+     - Supports subscripting: datasets["train"], datasets["validation"], datasets["test"]
+     Args:
+     - datasets: DatasetDict containing train/validation/test splits
+     - processor: DonutProcessor for image processing
+     - image_column: Column name for images in the dataset
+     - annotation_column: Column name for annotations in the dataset
+     - task_start_token: Token to start the task
+     - prompt_end_token: Token to end the prompt
+     - max_length: Maximum length of tokenized sequences
+     - train_split: Fraction of data to use for training (0.0-1.0)
+     - validation_split: Fraction of data to use for validation (0.0-1.0)
+     - test_split: Fraction of data to use for testing (0.0-1.0)
+     - ignore_index: Index to ignore in labels (default: -100)
+     - sort_json_key: Whether to sort JSON keys (default: True)
+     - seed: Random seed for reproducibility. If None, use OS random seed (default: None)
+     - shuffle: Whether to shuffle the dataset (default: True)
+     Returns:
+     - DonutDatasets object with train/validation/test splits
+     Example:
+         datasets = DonutDatasets(
+             datasets=dataset_dict,
+             processor=processor,
+             image_column="image",
+             annotation_column="annotation",
+             task_start_token="<s_task>",
+             prompt_end_token="<s_prompt>",
+             max_length=512,
+             train_split=0.8,
+             validation_split=0.1,
+             test_split=0.1
+         )
+         train_dataset = datasets["train"]
+         validation_dataset = datasets["validation"]
+         test_dataset = datasets["test"]
+     Note:
+     - The dataset must be a DatasetDict with train/validation/test splits
+     - The processor must be a DonutProcessor instance
+     - The image_column and annotation_column must exist in the dataset
+     - The task_start_token and prompt_end_token must be unique tokens
+     - The max_length should be set according to the model's maximum input length
+     - The ignore_index is used for padding in labels (default: -100)
+     - The sort_json_key option determines whether JSON keys are sorted or not
+     """
+     def __init__(
+         self,
+         datasets: DatasetDict,
+         processor: DonutProcessor,
+         image_column: str,
+         annotation_column: str,
+         task_start_token: str,
+         prompt_end_token: str,
+         max_length: int = 512,
+         train_split: float = 1.0,
+         validation_split: float = 0.0,
+         test_split: float = 0.0,
+         ignore_index: int = -100,
+         sort_json_key: bool = True,
+         seed: Optional[int] = None,
+         shuffle: bool = True
+     ):
+         assert abs(train_split + validation_split + test_split - 1.0) < 1e-6, (
+             "train/validation/test splits must sum to 1"
+         )
+         self.processor = processor
+         self.tokenizer = processor.tokenizer
+         self.image_column = image_column
+         self.annotation_column = annotation_column
+         self.max_length = max_length
+         self.task_start_token = task_start_token
+         self.prompt_end_token = prompt_end_token or task_start_token
+         self.ignore_index = ignore_index
+         self.sort_json_key = sort_json_key
+
+         # Perform split on provided datasets
+         raw = datasets
+         parts: Dict[str, Any] = {}
+         if train_split < 1.0:
+             split1 = raw["train"].train_test_split(test_size=1 - train_split, seed=seed, shuffle=shuffle)
+             parts["train"] = split1["train"]
+             rest = split1["test"]
+             if validation_split > 0:
+                 val_frac = validation_split / (validation_split + test_split)
+                 split2 = rest.train_test_split(test_size=1 - val_frac, seed=seed, shuffle=shuffle)
+                 parts["validation"] = split2["train"]
+                 parts["test"] = split2["test"]
+             else:
+                 parts["test"] = rest
+         else:
+             parts = dict(raw)
+
+         # Create individual split datasets
+         self._splits: Dict[str, Dataset] = {}
+         for name, ds in parts.items():
+             self._splits[name] = _SplitDataset(
+                 hf_dataset=ds,
+                 processor=self.processor,
+                 image_column=self.image_column,
+                 annotation_column=self.annotation_column,
+                 max_length=self.max_length,
+                 ignore_index=self.ignore_index,
+                 sort_json_key=self.sort_json_key,
+                 task_start_token=self.task_start_token,
+                 prompt_end_token=self.prompt_end_token,
+             )
+
+     def __getitem__(self, split: str) -> Dataset:
+         """
+         Return the dataset split by name, e.g., datasets["train"]
+         """
+         if split in self._splits:
+             return self._splits[split]
+         raise KeyError(f"Unknown split '{split}'. Available splits: {list(self._splits.keys())}")
+
+     def __repr__(self):
+         return f"DonutDatasets(splits={list(self._splits.keys())})"
+
+
+ class _SplitDataset(Dataset):
+     """
+     PyTorch Dataset for a single split, returns (pixel_values, labels, target_sequence)
+     """
+     def __init__(
+         self,
+         hf_dataset,
+         processor: DonutProcessor,
+         image_column: str,
+         annotation_column: str,
+         max_length: int,
+         ignore_index: int,
+         sort_json_key: bool,
+         task_start_token: str,
+         prompt_end_token: str,
+     ):
+         self.processor = processor
+         self.tokenizer = processor.tokenizer
+         self.hf_dataset = hf_dataset
+         self.image_column = image_column
+         self.annotation_column = annotation_column
+         self.max_length = max_length
+         self.ignore_index = ignore_index
+         self.sort_json_key = sort_json_key
+         self.task_start_token = task_start_token
+         self.prompt_end_token = prompt_end_token
+
+         # Prepare tokenized ground-truth sequences (single annotation)
+         self.gt_token_sequences = []
+         for sample in self.hf_dataset:
+             gt = sample[self.annotation_column]
+             if isinstance(gt, str):
+                 gt = json.loads(gt)
+             seq = self._json_to_token(gt) + self.tokenizer.eos_token
+             self.gt_token_sequences.append(seq)
+
+         # Add special tokens to tokenizer
+         self.tokenizer.add_tokens([self.task_start_token, self.prompt_end_token])
+
+     def _json_to_token(self, obj: Any) -> str:
+         if isinstance(obj, dict):
+             keys = sorted(obj.keys()) if self.sort_json_key else obj.keys()
+             seq = ""
+             for k in keys:
+                 open_tag = f"<s_{k}>"
+                 close_tag = f"</s_{k}>"
+                 self.tokenizer.add_special_tokens({"additional_special_tokens": [open_tag, close_tag]})
+                 seq += open_tag + self._json_to_token(obj[k]) + close_tag
+             return seq
+         if isinstance(obj, list):
+             return r"<sep/>".join(self._json_to_token(x) for x in obj)
+         return str(obj)
+
+     def __len__(self):
+         return len(self.hf_dataset)
+
+     def __getitem__(self, idx: int):
+         sample = self.hf_dataset[idx]
+         pixel_values = self.processor(sample[self.image_column], return_tensors="pt").pixel_values.squeeze()
+         target_seq = self.gt_token_sequences[idx]
+         tokens = self.tokenizer(
+             target_seq,
+             add_special_tokens=False,
+             max_length=self.max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         )
+         input_ids = tokens.input_ids.squeeze(0)
+         labels = input_ids.clone()
+         labels[labels == self.tokenizer.pad_token_id] = self.ignore_index
+         return {
+             "pixel_values": pixel_values,
+             "input_ids": input_ids,
+             "attention_mask": tokens.attention_mask.squeeze(0),
+             "labels": labels,
+             "target_sequence": target_seq
+         }
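`_json_to_token` above is what turns a JSON annotation into Donut's target sequence: each key becomes an `<s_key>…</s_key>` tag pair and list items are joined with `<sep/>`. A standalone sketch of that flattening, with the tokenizer bookkeeping stripped out and an invented annotation:

```python
# Standalone sketch of the JSON-to-token flattening in _json_to_token
# (tokenizer special-token registration omitted; sample annotation invented).
from typing import Any

def json_to_token(obj: Any, sort_json_key: bool = True) -> str:
    if isinstance(obj, dict):
        keys = sorted(obj.keys()) if sort_json_key else obj.keys()
        return "".join(f"<s_{k}>" + json_to_token(obj[k], sort_json_key) + f"</s_{k}>" for k in keys)
    if isinstance(obj, list):
        return "<sep/>".join(json_to_token(x, sort_json_key) for x in obj)
    return str(obj)

annotation = {"restaurant": "Example Cafe", "dishes": [{"name": "Latte", "price": 120}]}
print(json_to_token(annotation))
# -> <s_dishes><s_name>Latte</s_name><s_price>120</s_price></s_dishes><s_restaurant>Example Cafe</s_restaurant>
```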
menu/llm/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .gemini import GeminiAPI
+ from .openai import OpenAIAPI
menu/llm/base.py ADDED
@@ -0,0 +1,9 @@
+ from abc import ABC, abstractmethod
+
+ import numpy as np
+
+ class LLMBase(ABC):
+     @classmethod
+     @abstractmethod
+     def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+         raise NotImplementedError
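`LLMBase` defines the contract that the two providers below (`gemini.py`, `openai.py`) implement. A hypothetical stub, not part of the repo, just to show what adding another provider would look like:

```python
# Hypothetical provider stub (not in the repo), illustrating the LLMBase contract.
import numpy as np

from menu.llm.base import LLMBase

class DummyAPI(LLMBase):
    @classmethod
    def call(cls, image: np.ndarray, model: str, token: str) -> dict:
        # A real provider would send the image to its API here and return
        # the function-call arguments as a dict matching tools/schema_*.json.
        return {"restaurant": "", "address": "", "phone": "", "business_hours": "", "dishes": []}
```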
menu/llm/gemini.py ADDED
@@ -0,0 +1,36 @@
+ import json
+
+ import numpy as np
+ from PIL import Image
+ from google import genai
+ from google.genai import types
+
+ from .base import LLMBase
+
+ FUNCTION_CALL = json.load(open("tools/schema_gemini.json", "r"))
+
+ class GeminiAPI(LLMBase):
+     @classmethod
+     def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+         client = genai.Client(api_key=token) # Initialize the client with the API key
+         encode_img = Image.fromarray(image) # Convert the image for the API
+
+         config = types.GenerateContentConfig(
+             tools=[types.Tool(function_declarations=[FUNCTION_CALL])],
+             tool_config={
+                 "function_calling_config": {
+                     "mode": "ANY",
+                     "allowed_function_names": [FUNCTION_CALL["name"]]
+                 }
+             }
+         )
+         response = client.models.generate_content(
+             model=model,
+             contents=[encode_img],
+             config=config
+         )
+         if response.candidates[0].content.parts[0].function_call:
+             function_call = response.candidates[0].content.parts[0].function_call
+             return function_call.args
+
+         return {}
menu/llm/openai.py ADDED
@@ -0,0 +1,39 @@
+ import json
+ import base64
+ from io import BytesIO
+
+ import numpy as np
+ from PIL import Image
+ from openai import OpenAI
+
+ from .base import LLMBase
+
+ FUNCTION_CALL = json.load(open("tools/schema_openai.json", "r"))
+
+ class OpenAIAPI(LLMBase):
+     @classmethod
+     def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+         client = OpenAI(api_key=token) # Initialize the client with the API key
+         buffer = BytesIO()
+         Image.fromarray(image).save(buffer, format="JPEG")
+         encode_img = base64.b64encode(buffer.getvalue()).decode("utf-8") # Convert the image for the API
+
+         response = client.responses.create(
+             model=model,
+             input=[
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "input_image",
+                             "image_url": f"data:image/jpeg;base64,{encode_img}",
+                         },
+                     ],
+                 }
+             ],
+             tools=[FUNCTION_CALL],
+         )
+         if response and response.output:
+             if hasattr(response.output[0], "arguments"):
+                 return json.loads(response.output[0].arguments)
+         return {}
pyproject.toml ADDED
@@ -0,0 +1,23 @@
+ [project]
+ authors = [{name = "ryanlinjui", email = "[email protected]"}]
+ name = "menu-text-detection"
+ version = "0.1.0"
+ description = "Extract structured menu information from images into JSON using a fine-tuned Donut E2E model."
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "accelerate>=1.6.0",
+     "datasets>=3.6.0",
+     "dotenv>=0.9.9",
+     "google-genai>=1.14.0",
+     "gradio>=5.29.0",
+     "huggingface-hub>=0.31.1",
+     "matplotlib>=3.10.1",
+     "notebook>=7.4.2",
+     "openai>=1.77.0",
+     "pillow>=11.2.1",
+     "protobuf>=6.30.2",
+     "sentencepiece>=0.2.0",
+     "tensorboardx>=2.6.2.2",
+     "transformers>=4.51.3",
+ ]
requirements.txt ADDED
@@ -0,0 +1,169 @@
+ accelerate==1.6.0
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.18
+ aiosignal==1.3.2
+ annotated-types==0.7.0
+ anyio==4.9.0
+ appnope==0.1.4
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==3.0.0
+ async-lru==2.0.5
+ attrs==25.3.0
+ babel==2.17.0
+ beautifulsoup4==4.13.4
+ bleach==6.2.0
+ cachetools==5.5.2
+ certifi==2025.4.26
+ cffi==1.17.1
+ charset-normalizer==3.4.2
+ click==8.1.8
+ comm==0.2.2
+ contourpy==1.3.2
+ cycler==0.12.1
+ datasets==3.6.0
+ debugpy==1.8.14
+ decorator==5.2.1
+ defusedxml==0.7.1
+ dill==0.3.8
+ distro==1.9.0
+ dotenv==0.9.9
+ executing==2.2.0
+ fastapi==0.115.12
+ fastjsonschema==2.21.1
+ ffmpy==0.5.0
+ filelock==3.18.0
+ fonttools==4.57.0
+ fqdn==1.5.1
+ frozenlist==1.6.0
+ fsspec==2025.3.0
+ google-auth==2.40.1
+ google-genai==1.14.0
+ gradio==5.29.0
+ gradio-client==1.10.0
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.1.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.31.1
+ idna==3.10
+ ipykernel==6.29.5
+ ipython==9.2.0
+ ipython-pygments-lexers==1.1.1
+ isoduration==20.11.0
+ jedi==0.19.2
+ jinja2==3.1.6
+ jiter==0.9.0
+ json5==0.12.0
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2025.4.1
+ jupyter-client==8.6.3
+ jupyter-core==5.7.2
+ jupyter-events==0.12.0
+ jupyter-lsp==2.2.5
+ jupyter-server==2.15.0
+ jupyter-server-terminals==0.5.3
+ jupyterlab==4.4.2
+ jupyterlab-pygments==0.3.0
+ jupyterlab-server==2.27.3
+ kiwisolver==1.4.8
+ markdown-it-py==3.0.0
+ markupsafe==3.0.2
+ matplotlib==3.10.1
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mistune==3.1.3
+ mpmath==1.3.0
+ multidict==6.4.3
+ multiprocess==0.70.16
+ nbclient==0.10.2
+ nbconvert==7.16.6
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx==3.4.2
+ notebook==7.4.2
+ notebook-shim==0.2.4
+ numpy==2.2.5
+ openai==1.77.0
+ orjson==3.10.18
+ overrides==7.7.0
+ packaging==25.0
+ pandas==2.2.3
+ pandocfilters==1.5.1
+ parso==0.8.4
+ pexpect==4.9.0
+ pillow==11.2.1
+ platformdirs==4.3.8
+ prometheus-client==0.21.1
+ prompt-toolkit==3.0.51
+ propcache==0.3.1
+ protobuf==6.30.2
+ psutil==7.0.0
+ ptyprocess==0.7.0
+ pure-eval==0.2.3
+ pyarrow==20.0.0
+ pyasn1==0.6.1
+ pyasn1-modules==0.4.2
+ pycparser==2.22
+ pydantic==2.11.4
+ pydantic-core==2.33.2
+ pydub==0.25.1
+ pygments==2.19.1
+ pyparsing==3.2.3
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.0
+ python-json-logger==3.3.0
+ python-multipart==0.0.20
+ pytz==2025.2
+ pyyaml==6.0.2
+ pyzmq==26.4.0
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==14.0.0
+ rpds-py==0.24.0
+ rsa==4.9.1
+ ruff==0.11.8
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ semantic-version==2.10.0
+ send2trash==1.8.3
+ sentencepiece==0.2.0
+ setuptools==80.3.1
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ soupsieve==2.7
+ stack-data==0.6.3
+ starlette==0.46.2
+ sympy==1.14.0
+ terminado==0.18.1
+ tinycss2==1.4.0
+ tokenizers==0.21.1
+ tomlkit==0.13.2
+ torch==2.7.0
+ tornado==6.4.2
+ tqdm==4.67.1
+ traitlets==5.14.3
+ transformers==4.51.3
+ typer==0.15.3
+ types-python-dateutil==2.9.0.20241206
+ typing-extensions==4.13.2
+ typing-inspection==0.4.0
+ tzdata==2025.2
+ uri-template==1.3.0
+ urllib3==2.4.0
+ uvicorn==0.34.2
+ wcwidth==0.2.13
+ webcolors==24.11.1
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ websockets==15.0.1
+ xxhash==3.5.0
+ yarl==1.20.0
tools/schema_gemini.json ADDED
@@ -0,0 +1,45 @@
+ {
+     "name": "extract_menu_data",
+     "description": "Extract structured menu information from images.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "restaurant": {
+                 "type": "string",
+                 "description": "Name of the restaurant. If the name is not available, it should be ''."
+             },
+             "address": {
+                 "type": "string",
+                 "description": "Address of the restaurant. If the address is not available, it should be ''."
+             },
+             "phone": {
+                 "type": "string",
+                 "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+             },
+             "business_hours": {
+                 "type": "string",
+                 "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+             },
+             "dishes": {
+                 "type": "array",
+                 "items": {
+                     "type": "object",
+                     "properties": {
+                         "name": {
+                             "type": "string",
+                             "description": "Name of the menu item."
+                         },
+                         "price": {
+                             "type": "number",
+                             "format": "float",
+                             "description": "Price of the menu item. If the price is not available, it should be -1."
+                         }
+                     },
+                     "required": ["name", "price"]
+                 },
+                 "description": "List of menu dish items."
+             }
+         },
+         "required": ["restaurant", "address", "phone", "business_hours", "dishes"]
+     }
+ }
tools/schema_openai.json ADDED
@@ -0,0 +1,48 @@
+ {
+     "type": "function",
+     "name": "extract_menu_data",
+     "description": "Extract structured menu information from images.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "restaurant": {
+                 "type": "string",
+                 "description": "Name of the restaurant. If the name is not available, it should be ''."
+             },
+             "address": {
+                 "type": "string",
+                 "description": "Address of the restaurant. If the address is not available, it should be ''."
+             },
+             "phone": {
+                 "type": "string",
+                 "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+             },
+             "business_hours": {
+                 "type": "string",
+                 "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+             },
+             "dishes": {
+                 "type": "array",
+                 "items": {
+                     "type": "object",
+                     "properties": {
+                         "name": {
+                             "type": "string",
+                             "description": "Name of the menu item."
+                         },
+                         "price": {
+                             "type": "number",
+                             "format": "float",
+                             "description": "Price of the menu item. If the price is not available, it should be -1."
+                         }
+                     },
+                     "required": ["name", "price"],
+                     "additionalProperties": false
+                 },
+                 "description": "List of menu dish items."
+             }
+         },
+         "required": ["restaurant", "address", "phone", "business_hours", "dishes"],
+         "additionalProperties": false
+     }
+ }
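Both schemas share the same `parameters` object, so a model's function-call arguments can be checked locally before use. A minimal validation sketch using the `jsonschema` package (pinned in `requirements.txt`); the sample result is invented:

```python
# Minimal sketch: validate function-call arguments against the shared
# "parameters" object of tools/schema_openai.json. Sample data invented.
import json
from jsonschema import validate, ValidationError

with open("tools/schema_openai.json", "r") as f:
    schema = json.load(f)["parameters"]

result = {
    "restaurant": "Example Cafe",
    "address": "",
    "phone": "",
    "business_hours": "",
    "dishes": [{"name": "Latte", "price": 120}]
}

try:
    validate(instance=result, schema=schema)
    print("result matches the menu schema")
except ValidationError as e:
    print(f"schema violation: {e.message}")
```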
train.ipynb ADDED
@@ -0,0 +1,294 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Login to HuggingFace (just log in once)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from huggingface_hub import interpreter_login\n",
+     "interpreter_login()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Collect Menu Image Datasets\n",
+     "- Use `metadata.jsonl` to label the images' ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.\n",
+     "- After finishing, push to HuggingFace Datasets.\n",
+     "- For labeling:\n",
+     "  - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n",
+     "  - Use function calling via the API. Start the Gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n",
+     "\n",
+     "### Menu Type\n",
+     "- **h**: horizontal menu\n",
+     "- **v**: vertical menu\n",
+     "- **d**: document-style menu\n",
+     "- **s**: in-scene menu (non-document style)\n",
+     "- **i**: irregular menu (menu with irregular text layout)\n",
+     "\n",
+     "> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from datasets import load_dataset\n",
+     "\n",
+     "dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load the dataset from the local directory containing metadata.jsonl and the image files\n",
+     "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the Hugging Face dataset hub"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Setup for Fine-tuning"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from datasets import load_dataset\n",
+     "from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig\n",
+     "\n",
+     "from menu.donut import DonutDatasets\n",
+     "\n",
+     "DATASETS_REPO_ID = \"ryanlinjui/menu-zh-TW\" # set your dataset repo id for training\n",
+     "PRETRAINED_MODEL_REPO_ID = \"naver-clova-ix/donut-base\" # set your pretrained model repo id for fine-tuning\n",
+     "TASK_PROMPT_NAME = \"<s_menu>\" # set your task prompt name for training\n",
+     "MAX_LENGTH = 768 # set the maximum output length\n",
+     "IMAGE_SIZE = [1280, 960] # set your image size for training\n",
+     "\n",
+     "raw_datasets = load_dataset(DATASETS_REPO_ID)\n",
+     "\n",
+     "# Config: set the model config\n",
+     "config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_REPO_ID)\n",
+     "config.encoder.image_size = IMAGE_SIZE\n",
+     "config.decoder.max_length = MAX_LENGTH\n",
+     "\n",
+     "# Processor: use the processor to process the dataset.\n",
+     "# Convert the image to a tensor and the text to token ids.\n",
+     "processor = DonutProcessor.from_pretrained(PRETRAINED_MODEL_REPO_ID)\n",
+     "processor.feature_extractor.size = IMAGE_SIZE[::-1]\n",
+     "processor.feature_extractor.do_align_long_axis = False\n",
+     "\n",
+     "# DonutDatasets: use DonutDatasets to process the dataset.\n",
+     "# For model input, the image must be converted to a tensor and the JSON text must be converted to tokens with the task prompt string.\n",
+     "# This example sets the column names to \"image\" and \"menu\", so the image file is in the \"image\" column and the JSON text is in the \"menu\" column.\n",
+     "datasets = DonutDatasets(\n",
+     "    datasets=raw_datasets,\n",
+     "    processor=processor,\n",
+     "    image_column=\"image\",\n",
+     "    annotation_column=\"menu\",\n",
+     "    task_start_token=TASK_PROMPT_NAME,\n",
+     "    prompt_end_token=TASK_PROMPT_NAME,\n",
+     "    train_split=0.8,\n",
+     "    validation_split=0.1,\n",
+     "    test_split=0.1,\n",
+     "    sort_json_key=True,\n",
+     "    seed=42\n",
+     ")\n",
+     "\n",
+     "# Model: load the pretrained model and set the config.\n",
+     "model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_REPO_ID, config=config)\n",
+     "model.decoder.resize_token_embeddings(len(processor.tokenizer))\n",
+     "model.config.pad_token_id = processor.tokenizer.pad_token_id\n",
+     "model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Start Fine-tuning"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
+     "\n",
+     "HUGGINGFACE_MODEL_ID = \"ryanlinjui/donut-base-finetuned-menu\" # set your huggingface model repo id for saving / pushing to the hub\n",
+     "EPOCHS = 100 # set your training epochs\n",
+     "TRAIN_BATCH_SIZE = 4 # set your training batch size\n",
+     "\n",
+     "device = (\n",
+     "    \"cuda\"\n",
+     "    if torch.cuda.is_available()\n",
+     "    else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
+     ")\n",
+     "print(f\"Using {device} device\")\n",
+     "model.to(device)\n",
+     "\n",
+     "training_args = Seq2SeqTrainingArguments(\n",
+     "    num_train_epochs=EPOCHS,\n",
+     "    per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
+     "    learning_rate=3e-5,\n",
+     "    per_device_eval_batch_size=1,\n",
+     "    output_dir=\"./.checkpoints\",\n",
+     "    seed=2022,\n",
+     "    warmup_steps=30,\n",
+     "    eval_strategy=\"steps\",\n",
+     "    eval_steps=100,\n",
+     "    logging_strategy=\"steps\",\n",
+     "    logging_steps=50,\n",
+     "    save_strategy=\"steps\",\n",
+     "    save_steps=200,\n",
+     "    push_to_hub=True if HUGGINGFACE_MODEL_ID else False,\n",
+     "    hub_model_id=HUGGINGFACE_MODEL_ID,\n",
+     "    hub_strategy=\"every_save\",\n",
+     "    report_to=\"tensorboard\",\n",
+     "    logging_dir=\"./.checkpoints/logs\",\n",
+     ")\n",
+     "trainer = Seq2SeqTrainer(\n",
+     "    model=model,\n",
+     "    args=training_args,\n",
+     "    train_dataset=datasets[\"train\"],\n",
+     "    eval_dataset=datasets[\"test\"],\n",
+     "    tokenizer=processor\n",
+     ")\n",
+     "\n",
+     "trainer.train()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from transformers import (\n",
+     "    VisionEncoderDecoderModel,\n",
+     "    DonutProcessor,\n",
+     "    pipeline\n",
+     ")\n",
+     "from PIL import Image\n",
+     "\n",
+     "model_id = \"ryanlinjui/donut-base-finetuned-menu\"\n",
+     "\n",
+     "# 1. Download and load the model + processor\n",
+     "processor = DonutProcessor.from_pretrained(model_id)\n",
+     "model = VisionEncoderDecoderModel.from_pretrained(model_id)\n",
+     "\n",
+     "# 2. Build an image-to-text pipeline\n",
+     "ocr_pipeline = pipeline(\n",
+     "    \"image-to-text\", # use the image-to-text task\n",
+     "    model=model, # pass in the loaded model\n",
+     "    tokenizer=processor.tokenizer,\n",
+     "    feature_extractor=processor.feature_extractor,\n",
+     ")\n",
+     "\n",
+     "# 3. Load a test image\n",
+     "image = Image.open(\"./examples/menu-hd.jpg\")\n",
+     "\n",
+     "# 4. Call the pipeline and get the result\n",
+     "outputs = ocr_pipeline(image)\n",
+     "\n",
+     "# 5. Print the recognized text\n",
+     "print(outputs[0][\"generated_text\"])\n",
+     "\n",
+     "'''\n",
+     "# test model\n",
+     "import re\n",
+     "\n",
+     "from transformers import VisionEncoderDecoderModel\n",
+     "from transformers import DonutProcessor\n",
+     "import torch\n",
+     "from PIL import Image\n",
+     "\n",
+     "image = Image.open(\"./examples/menu-hd.jpg\").convert(\"RGB\")\n",
+     "\n",
+     "processor = DonutProcessor.from_pretrained(\"ryanlinjui/donut-base-finetuned-menu\")\n",
+     "model = VisionEncoderDecoderModel.from_pretrained(\"ryanlinjui/donut-base-finetuned-menu\")\n",
+     "device = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
+     "\n",
+     "model.eval()\n",
+     "model.to(device)\n",
+     "\n",
+     "pixel_values = processor(image, return_tensors=\"pt\").pixel_values\n",
+     "pixel_values = pixel_values.to(device)\n",
+     "\n",
+     "task_prompt = \"<s_menu>\"\n",
+     "decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors=\"pt\").input_ids\n",
+     "decoder_input_ids = decoder_input_ids.to(device)\n",
+     "outputs = model.generate(\n",
+     "    pixel_values,\n",
+     "    decoder_input_ids=decoder_input_ids,\n",
+     "    max_length=model.decoder.config.max_position_embeddings,\n",
+     "    early_stopping=True,\n",
+     "    pad_token_id=processor.tokenizer.pad_token_id,\n",
+     "    eos_token_id=processor.tokenizer.eos_token_id,\n",
+     "    use_cache=True,\n",
+     "    num_beams=1,\n",
+     "    bad_words_ids=[[processor.tokenizer.unk_token_id]],\n",
+     "    return_dict_in_generate=True,\n",
+     ")\n",
+     "\n",
+     "seq = processor.batch_decode(outputs.sequences)[0]\n",
+     "seq = seq.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n",
+     "# seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove the first task start token\n",
+     "seq = processor.token2json(seq)\n",
+     "print(seq)\n",
+     "'''\n"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Plot the results"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Training loss\n",
+     "# Validation normalized edit distance per epoch: from 1 down to 0.22\n",
+     "# Test accuracy: TED accuracy 0.687058, F1 score 0.51119"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
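The results cell in `train.ipynb` reports a validation normalized edit distance falling from 1 to about 0.22, plus test TED/F1 figures. As a rough sketch of one common way normalized edit distance is computed for Donut-style outputs (plain Levenshtein over the decoded sequence divided by the reference length; this is an assumption, not the notebook's own evaluation code):

```python
# Sketch of a normalized edit distance between a predicted and a ground-truth
# token sequence (assumed metric, not the notebook's exact evaluation code).
def levenshtein(a: str, b: str) -> int:
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]

def normalized_edit_distance(pred: str, gt: str) -> float:
    return levenshtein(pred, gt) / max(len(gt), 1)

pred = "<s_menu><s_restaurant>Example</s_restaurant></s_menu>"
gt = "<s_menu><s_restaurant>Sample</s_restaurant></s_menu>"
print(normalized_edit_distance(pred, gt))
```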
uv.lock ADDED
The diff for this file is too large to render. See raw diff