github-actions[bot] committed · Commit 8af6af2 · 0 Parent(s)
Sync from https://github.com/ryanlinjui/menu-text-detection

Files changed:
- .checkpoints/.gitkeep +0 -0
- .env.example +3 -0
- .github/workflows/sync.yml +25 -0
- .gitignore +24 -0
- .python-version +1 -0
- LICENSE +21 -0
- README.md +65 -0
- app.py +151 -0
- menu/donut.py +225 -0
- menu/llm/__init__.py +2 -0
- menu/llm/base.py +9 -0
- menu/llm/gemini.py +36 -0
- menu/llm/openai.py +39 -0
- pyproject.toml +23 -0
- requirements.txt +169 -0
- tools/schema_gemini.json +45 -0
- tools/schema_openai.json +48 -0
- train.ipynb +294 -0
- uv.lock +0 -0
.checkpoints/.gitkeep
ADDED
File without changes
.env.example
ADDED
@@ -0,0 +1,3 @@
+HUGGINGFACE_TOKEN="HUGGINGFACE_TOKEN"
+GIMINI_API_TOKEN="GIMINI_API_TOKEN"
+OPENAI_API_TOKEN="OPENAI_API_TOKEN"
.github/workflows/sync.yml
ADDED
@@ -0,0 +1,25 @@
+name: Sync to Hugging Face Spaces
+
+on:
+  push:
+    branches:
+      - main
+jobs:
+  sync:
+    name: Sync
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+
+      - name: Remove bad files
+        run: rm -rf examples assets
+
+      - name: Sync to Hugging Face Spaces
+        uses: JacobLinCool/huggingface-sync@v1
+        with:
+          github: ${{ secrets.GITHUB_TOKEN }}
+          user: ryanlinjui # Hugging Face username or organization name
+          space: menu-text-detection # Hugging Face space name
+          token: ${{ secrets.HF_TOKEN }} # Hugging Face token
+          python_version: 3.11 # Python version
.gitignore
ADDED
@@ -0,0 +1,24 @@
+# mac
+.DS_Store
+
+# cache
+__pycache__
+
+# datasets
+datasets
+
+# papers
+docs/papers
+
+# uv
+.venv
+
+# gradio
+.gradio
+
+# env
+.env
+
+# checkpoint
+.checkpoints/*
+!.checkpoints/.gitkeep
.python-version
ADDED
@@ -0,0 +1 @@
+3.11
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 RyanLin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,65 @@
+---
+title: menu text detection
+emoji: 🦄
+colorFrom: indigo
+colorTo: pink
+sdk: gradio
+python_version: 3.11
+short_description: Extract structured menu information from images into JSON...
+tags: [ "donut","fine-tuning","image-to-text","transformer" ]
+---
+
+# Menu Text Detection System
+
+Extract structured menu information from images into JSON using a fine-tuned Donut E2E model.
+> Based on [Donut by Clova AI (ECCV ’22)](https://github.com/clovaai/donut)
+
+<div align="center">
+
+<img src="./assets/demo.gif" alt="demo" width="500"/><br>
+
+[](https://huggingface.co/spaces/ryanlinjui/menu-text-detection)<br>
+[](https://huggingface.co/collections/ryanlinjui/menu-text-detection-670ccf527626bb004bbfb39b)
+
+</div>
+
+## 🚀 Features
+### Overview
+Currently supports extracting the following information from menu images:
+
+- **Restaurant Name**
+- **Business Hours**
+- **Address**
+- **Phone Number**
+- **Dish Information**
+  - Name
+  - Price
+
+> For the JSON schema, see the [tools directory](./tools).
+
+### Supported Methods to Extract Menu Information
+- Fine-tuned Donut model
+- OpenAI GPT API
+- Google Gemini API
+
+## 💻 Training / Fine-Tuning
+### Setup
+Use [uv](https://github.com/astral-sh/uv) to set up the development environment:
+
+```bash
+uv sync
+```
+
+### Training Script (Dataset Collection, Fine-Tuning)
+Please refer to [`train.ipynb`](./train.ipynb). Use Jupyter Notebook for training:
+
+```bash
+uv run jupyter-notebook
+```
+
+> For VSCode users, please install the Jupyter extension, then select `.venv/bin/python` as your kernel.
+
+### Run Demo Locally
+```bash
+uv run python app.py
+```
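For orientation before reading `app.py` below: the demo's Donut path wraps a standard `transformers` image-to-text pipeline. A minimal standalone inference sketch, assuming the fine-tuned checkpoint `ryanlinjui/donut-base-finetuned-menu` referenced in `train.ipynb` and a hypothetical local image path:

```python
# Minimal sketch: run the fine-tuned Donut model outside the Gradio app.
# Assumes the ryanlinjui/donut-base-finetuned-menu checkpoint from train.ipynb;
# "menu.jpg" is a hypothetical local image path.
from PIL import Image
from transformers import pipeline

menu_pipeline = pipeline(task="image-to-text", model="ryanlinjui/donut-base-finetuned-menu")
outputs = menu_pipeline(Image.open("menu.jpg"))
print(outputs[0]["generated_text"])  # raw <s_...> token sequence; see menu/donut.py
```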
app.py
ADDED
@@ -0,0 +1,151 @@
+import os
+import json
+
+import numpy as np
+import gradio as gr
+from dotenv import load_dotenv
+
+from menu.llm import (
+    GeminiAPI,
+    OpenAIAPI
+)
+from menu.donut import DonutFinetuned
+
+load_dotenv()
+GEMINI_API_TOKEN = os.getenv("GIMINI_API_TOKEN", "")
+OPENAI_API_TOKEN = os.getenv("OPENAI_API_TOKEN", "")
+
+SOURCE_CODE_GH_URL = "https://github.com/ryanlinjui/menu-text-detection"
+BADGE_URL = "https://img.shields.io/badge/GitHub_Code-Click_Here!!-default?logo=github"
+
+GITHUB_RAW_URL = "https://raw.githubusercontent.com/ryanlinjui/menu-text-detection/main"
+EXAMPLE_IMAGE_LIST = [
+    f"{GITHUB_RAW_URL}/examples/menu-hd.jpg",
+    f"{GITHUB_RAW_URL}/examples/menu-vs.jpg",
+    f"{GITHUB_RAW_URL}/examples/menu-si.jpg"
+]
+MODEL_LIST = [
+    "Donut Model",
+    "gemini-2.0-flash",
+    "gemini-2.5-flash-preview-04-17",
+    "gemini-2.5-pro-preview-03-25",
+    "gpt-4.1",
+    "gpt-4o",
+    "o4-mini"
+]
+
+def handle(image: np.ndarray, model: str, api_token: str) -> str:
+    if image is None:
+        raise gr.Error("Please upload an image first.")
+
+    if model == MODEL_LIST[0]:
+        result = DonutFinetuned.predict(image)
+
+    elif model in MODEL_LIST[1:]:
+        if len(api_token) < 10:
+            raise gr.Error(f"Please provide a valid token for {model}.")
+        try:
+            if model in MODEL_LIST[1:4]:
+                result = GeminiAPI.call(image, model, api_token)
+            else:
+                result = OpenAIAPI.call(image, model, api_token)
+        except Exception as e:
+            raise gr.Error(f"Failed to process with API model {model}: {str(e)}")
+    else:
+        raise gr.Error("Invalid model selection. Please choose a valid model.")
+
+    return json.dumps(result, indent=4, ensure_ascii=False)
+
+def UserInterface() -> gr.Interface:
+    with gr.Blocks(
+        delete_cache=(86400, 86400),
+        css="""
+        .image-panel {
+            display: flex;
+            flex-direction: column;
+            height: 600px;
+        }
+        .image-panel img {
+            object-fit: contain;
+            max-height: 600px;
+            max-width: 600px;
+            width: 100%;
+        }
+        .large-text textarea {
+            font-size: 20px !important;
+            height: 600px !important;
+            width: 100% !important;
+        }
+        """
+    ) as gradio_interface:
+        gr.HTML(f'<a href="{SOURCE_CODE_GH_URL}"><img src="{BADGE_URL}" alt="GitHub Code"/></a>')
+        gr.Markdown("# Menu Text Detection")
+
+        with gr.Row():
+            with gr.Column(scale=1, min_width=500):
+                gr.Markdown("## 📷 Menu Image")
+                menu_image = gr.Image(
+                    type="numpy",
+                    label="Input menu image",
+                    elem_classes="image-panel"
+                )
+
+                gr.Markdown("## 🤖 Model Selection")
+                model_choice_dropdown = gr.Dropdown(
+                    choices=MODEL_LIST,
+                    value=MODEL_LIST[0],
+                    label="Select Text Detection Model"
+                )
+
+                api_token_textbox = gr.Textbox(
+                    label="API Token",
+                    placeholder="Enter your API token here...",
+                    type="password",
+                    visible=False
+                )
+
+                generate_button = gr.Button("Generate Menu Information", variant="primary")
+
+                gr.Examples(
+                    examples=EXAMPLE_IMAGE_LIST,
+                    inputs=menu_image,
+                    label="Example Menu Images"
+                )
+
+            with gr.Column(scale=1):
+                gr.Markdown("## 🍽️ Menu Info")
+                menu_json_textbox = gr.Textbox(
+                    label="Output JSON",
+                    interactive=False,
+                    text_align="left",
+                    elem_classes="large-text"
+                )
+
+        def update_token_visibility(choice):
+            if choice in MODEL_LIST[1:]:
+                current_token = ""
+                if choice in MODEL_LIST[1:4]:
+                    current_token = GEMINI_API_TOKEN
+                elif choice in MODEL_LIST[4:]:
+                    current_token = OPENAI_API_TOKEN
+                return gr.update(visible=True, value=current_token)
+            else:
+                return gr.update(visible=False)
+
+        model_choice_dropdown.change(
+            fn=update_token_visibility,
+            inputs=model_choice_dropdown,
+            outputs=api_token_textbox
+        )
+
+        generate_button.click(
+            fn=handle,
+            inputs=[menu_image, model_choice_dropdown, api_token_textbox],
+            outputs=menu_json_textbox
+        )
+
+    return gradio_interface
+
+if __name__ == "__main__":
+    demo = UserInterface()
+    demo.launch()
menu/donut.py
ADDED
@@ -0,0 +1,225 @@
+import json
+from typing import Any, Dict, Optional
+
+import numpy as np
+from PIL import Image
+from datasets import DatasetDict
+from torch.utils.data import Dataset
+from transformers import pipeline, DonutProcessor
+
+class DonutFinetuned:
+    DEFAULT_PIPELINE = pipeline(
+        task="image-to-text",
+        model="naver-clova-ix/donut-base"
+    )
+    @classmethod
+    def predict(cls, image: np.ndarray) -> dict:
+        image = Image.fromarray(image)
+        result = cls.DEFAULT_PIPELINE(image)
+        return result
+
+class DonutDatasets:
+    """
+    Modified from:
+    https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/CORD/Fine_tune_Donut_on_a_custom_dataset_(CORD)_with_PyTorch_Lightning.ipynb
+
+    Donut PyTorch Dataset Wrapper (supports train/validation/test splits)
+    - Dynamic field names and JSON-to-token conversion
+    - Returns PyTorch Datasets with __getitem__ producing tensors
+    - Splits controlled by train_split/validation_split/test_split
+    - Only single JSON annotation supported
+    - Supports subscripting: datasets["train"], datasets["validation"], datasets["test"]
+    Args:
+        - datasets: DatasetDict containing train/validation/test splits
+        - processor: DonutProcessor for image processing
+        - image_column: Column name for images in the dataset
+        - annotation_column: Column name for annotations in the dataset
+        - task_start_token: Token to start the task
+        - prompt_end_token: Token to end the prompt
+        - max_length: Maximum length of tokenized sequences
+        - train_split: Fraction of data to use for training (0.0-1.0)
+        - validation_split: Fraction of data to use for validation (0.0-1.0)
+        - test_split: Fraction of data to use for testing (0.0-1.0)
+        - ignore_index: Index to ignore in labels (default: -100)
+        - sort_json_key: Whether to sort JSON keys (default: True)
+        - seed: Random seed for reproducibility. If None, use OS random seed (default: None)
+        - shuffle: Whether to shuffle the dataset (default: True)
+    Returns:
+        - DonutDatasets object with train/validation/test splits
+    Example:
+        datasets = DonutDatasets(
+            datasets=dataset_dict,
+            processor=processor,
+            image_column="image",
+            annotation_column="annotation",
+            task_start_token="<s_task>",
+            prompt_end_token="<s_prompt>",
+            max_length=512,
+            train_split=0.8,
+            validation_split=0.1,
+            test_split=0.1
+        )
+        train_dataset = datasets["train"]
+        validation_dataset = datasets["validation"]
+        test_dataset = datasets["test"]
+    Note:
+        - The dataset must be a DatasetDict with train/validation/test splits
+        - The processor must be a DonutProcessor instance
+        - The image_column and annotation_column must exist in the dataset
+        - The task_start_token and prompt_end_token must be unique tokens
+        - The max_length should be set according to the model's maximum input length
+        - The ignore_index is used for padding in labels (default: -100)
+        - The sort_json_key option determines whether JSON keys are sorted or not
+    """
+    def __init__(
+        self,
+        datasets: DatasetDict,
+        processor: DonutProcessor,
+        image_column: str,
+        annotation_column: str,
+        task_start_token: str,
+        prompt_end_token: str,
+        max_length: int = 512,
+        train_split: float = 1.0,
+        validation_split: float = 0.0,
+        test_split: float = 0.0,
+        ignore_index: int = -100,
+        sort_json_key: bool = True,
+        seed: Optional[int] = None,
+        shuffle: bool = True
+    ):
+        assert abs(train_split + validation_split + test_split - 1.0) < 1e-6, (
+            "train/validation/test splits must sum to 1"
+        )
+        self.processor = processor
+        self.tokenizer = processor.tokenizer
+        self.image_column = image_column
+        self.annotation_column = annotation_column
+        self.max_length = max_length
+        self.task_start_token = task_start_token
+        self.prompt_end_token = prompt_end_token or task_start_token
+        self.ignore_index = ignore_index
+        self.sort_json_key = sort_json_key
+
+        # Perform split on provided datasets
+        raw = datasets
+        parts: Dict[str, Any] = {}
+        if train_split < 1.0:
+            split1 = raw["train"].train_test_split(test_size=1 - train_split, seed=seed, shuffle=shuffle)
+            parts["train"] = split1["train"]
+            rest = split1["test"]
+            if validation_split > 0:
+                val_frac = validation_split / (validation_split + test_split)
+                split2 = rest.train_test_split(test_size=1 - val_frac, seed=seed, shuffle=shuffle)
+                parts["validation"] = split2["train"]
+                parts["test"] = split2["test"]
+            else:
+                parts["test"] = rest
+        else:
+            parts = dict(raw)
+
+        # Create individual split datasets
+        self._splits: Dict[str, Dataset] = {}
+        for name, ds in parts.items():
+            self._splits[name] = _SplitDataset(
+                hf_dataset=ds,
+                processor=self.processor,
+                image_column=self.image_column,
+                annotation_column=self.annotation_column,
+                max_length=self.max_length,
+                ignore_index=self.ignore_index,
+                sort_json_key=self.sort_json_key,
+                task_start_token=self.task_start_token,
+                prompt_end_token=self.prompt_end_token,
+            )
+
+    def __getitem__(self, split: str) -> Dataset:
+        """
+        Return the dataset split by name, e.g., datasets["train"]
+        """
+        if split in self._splits:
+            return self._splits[split]
+        raise KeyError(f"Unknown split '{split}'. Available splits: {list(self._splits.keys())}")
+
+    def __repr__(self):
+        return f"DonutDatasets(splits={list(self._splits.keys())})"
+
+
+class _SplitDataset(Dataset):
+    """
+    PyTorch Dataset for a single split, returns (pixel_values, labels, target_sequence)
+    """
+    def __init__(
+        self,
+        hf_dataset,
+        processor: DonutProcessor,
+        image_column: str,
+        annotation_column: str,
+        max_length: int,
+        ignore_index: int,
+        sort_json_key: bool,
+        task_start_token: str,
+        prompt_end_token: str,
+    ):
+        self.processor = processor
+        self.tokenizer = processor.tokenizer
+        self.hf_dataset = hf_dataset
+        self.image_column = image_column
+        self.annotation_column = annotation_column
+        self.max_length = max_length
+        self.ignore_index = ignore_index
+        self.sort_json_key = sort_json_key
+        self.task_start_token = task_start_token
+        self.prompt_end_token = prompt_end_token
+
+        # Prepare tokenized ground-truth sequences (single annotation)
+        self.gt_token_sequences = []
+        for sample in self.hf_dataset:
+            gt = sample[self.annotation_column]
+            if isinstance(gt, str):
+                gt = json.loads(gt)
+            seq = self._json_to_token(gt) + self.tokenizer.eos_token
+            self.gt_token_sequences.append(seq)
+
+        # Add special tokens to tokenizer
+        self.tokenizer.add_tokens([self.task_start_token, self.prompt_end_token])
+
+    def _json_to_token(self, obj: Any) -> str:
+        if isinstance(obj, dict):
+            keys = sorted(obj.keys()) if self.sort_json_key else obj.keys()
+            seq = ""
+            for k in keys:
+                open_tag = f"<s_{k}>"
+                close_tag = f"</s_{k}>"
+                self.tokenizer.add_special_tokens({"additional_special_tokens": [open_tag, close_tag]})
+                seq += open_tag + self._json_to_token(obj[k]) + close_tag
+            return seq
+        if isinstance(obj, list):
+            return r"<sep/>".join(self._json_to_token(x) for x in obj)
+        return str(obj)
+
+    def __len__(self):
+        return len(self.hf_dataset)
+
+    def __getitem__(self, idx: int):
+        sample = self.hf_dataset[idx]
+        pixel_values = self.processor(sample[self.image_column], return_tensors="pt").pixel_values.squeeze()
+        target_seq = self.gt_token_sequences[idx]
+        tokens = self.tokenizer(
+            target_seq,
+            add_special_tokens=False,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+        )
+        input_ids = tokens.input_ids.squeeze(0)
+        labels = input_ids.clone()
+        labels[labels == self.tokenizer.pad_token_id] = self.ignore_index
+        return {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": tokens.attention_mask.squeeze(0),
+            "labels": labels,
+            "target_sequence": target_seq
+        }
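The JSON-to-token scheme in `_SplitDataset._json_to_token` above is easiest to see on a concrete input. A standalone sketch (special-token registration omitted, `sort_json_key=True` behavior) showing the target sequence the model is trained to emit:

```python
# Standalone sketch of _SplitDataset._json_to_token (tokenizer registration
# omitted) to show the flattened target sequence format.
def json_to_token(obj) -> str:
    if isinstance(obj, dict):
        # sorted keys, each wrapped in <s_key>...</s_key>
        return "".join(f"<s_{k}>{json_to_token(obj[k])}</s_{k}>" for k in sorted(obj))
    if isinstance(obj, list):
        # list items joined with the <sep/> separator token
        return "<sep/>".join(json_to_token(x) for x in obj)
    return str(obj)

print(json_to_token({"restaurant": "Cafe", "dishes": [{"name": "Tea", "price": 30}]}))
# <s_dishes><s_name>Tea</s_name><s_price>30</s_price></s_dishes><s_restaurant>Cafe</s_restaurant>
```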
menu/llm/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .gemini import GeminiAPI
+from .openai import OpenAIAPI
menu/llm/base.py
ADDED
@@ -0,0 +1,9 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+
+class LLMBase(ABC):
+    @classmethod
+    @abstractmethod
+    def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+        raise NotImplementedError
menu/llm/gemini.py
ADDED
@@ -0,0 +1,36 @@
+import json
+
+import numpy as np
+from PIL import Image
+from google import genai
+from google.genai import types
+
+from .base import LLMBase
+
+FUNCTION_CALL = json.load(open("tools/schema_gemini.json", "r"))
+
+class GeminiAPI(LLMBase):
+    @classmethod
+    def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+        client = genai.Client(api_key=token)  # Initialize the client with the API key
+        encode_img = Image.fromarray(image)  # Convert the image for the API
+
+        config = types.GenerateContentConfig(
+            tools=[types.Tool(function_declarations=[FUNCTION_CALL])],
+            tool_config={
+                "function_calling_config": {
+                    "mode": "ANY",
+                    "allowed_function_names": [FUNCTION_CALL["name"]]
+                }
+            }
+        )
+        response = client.models.generate_content(
+            model=model,
+            contents=[encode_img],
+            config=config
+        )
+        if response.candidates[0].content.parts[0].function_call:
+            function_call = response.candidates[0].content.parts[0].function_call
+            return function_call.args
+
+        return {}
menu/llm/openai.py
ADDED
@@ -0,0 +1,39 @@
+import json
+import base64
+from io import BytesIO
+
+import numpy as np
+from PIL import Image
+from openai import OpenAI
+
+from .base import LLMBase
+
+FUNCTION_CALL = json.load(open("tools/schema_openai.json", "r"))
+
+class OpenAIAPI(LLMBase):
+    @classmethod
+    def call(cls, image: np.ndarray, model: str, token: str) -> dict:
+        client = OpenAI(api_key=token)  # Initialize the client with the API key
+        buffer = BytesIO()
+        Image.fromarray(image).save(buffer, format="JPEG")
+        encode_img = base64.b64encode(buffer.getvalue()).decode("utf-8")  # Convert the image for the API
+
+        response = client.responses.create(
+            model=model,
+            input=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "input_image",
+                            "image_url": f"data:image/jpeg;base64,{encode_img}",
+                        },
+                    ],
+                }
+            ],
+            tools=[FUNCTION_CALL],
+        )
+        if response and response.output:
+            if hasattr(response.output[0], "arguments"):
+                return json.loads(response.output[0].arguments)
+        return {}
pyproject.toml
ADDED
@@ -0,0 +1,23 @@
+[project]
+authors = [{name = "ryanlinjui", email = "[email protected]"}]
+name = "menu-text-detection"
+version = "0.1.0"
+description = "Extract structured menu information from images into JSON using a fine-tuned Donut E2E model."
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "accelerate>=1.6.0",
+    "datasets>=3.6.0",
+    "dotenv>=0.9.9",
+    "google-genai>=1.14.0",
+    "gradio>=5.29.0",
+    "huggingface-hub>=0.31.1",
+    "matplotlib>=3.10.1",
+    "notebook>=7.4.2",
+    "openai>=1.77.0",
+    "pillow>=11.2.1",
+    "protobuf>=6.30.2",
+    "sentencepiece>=0.2.0",
+    "tensorboardx>=2.6.2.2",
+    "transformers>=4.51.3",
+]
requirements.txt
ADDED
@@ -0,0 +1,169 @@
+accelerate==1.6.0
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.18
+aiosignal==1.3.2
+annotated-types==0.7.0
+anyio==4.9.0
+appnope==0.1.4
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==3.0.0
+async-lru==2.0.5
+attrs==25.3.0
+babel==2.17.0
+beautifulsoup4==4.13.4
+bleach==6.2.0
+cachetools==5.5.2
+certifi==2025.4.26
+cffi==1.17.1
+charset-normalizer==3.4.2
+click==8.1.8
+comm==0.2.2
+contourpy==1.3.2
+cycler==0.12.1
+datasets==3.6.0
+debugpy==1.8.14
+decorator==5.2.1
+defusedxml==0.7.1
+dill==0.3.8
+distro==1.9.0
+dotenv==0.9.9
+executing==2.2.0
+fastapi==0.115.12
+fastjsonschema==2.21.1
+ffmpy==0.5.0
+filelock==3.18.0
+fonttools==4.57.0
+fqdn==1.5.1
+frozenlist==1.6.0
+fsspec==2025.3.0
+google-auth==2.40.1
+google-genai==1.14.0
+gradio==5.29.0
+gradio-client==1.10.0
+groovy==0.1.2
+h11==0.16.0
+hf-xet==1.1.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.31.1
+idna==3.10
+ipykernel==6.29.5
+ipython==9.2.0
+ipython-pygments-lexers==1.1.1
+isoduration==20.11.0
+jedi==0.19.2
+jinja2==3.1.6
+jiter==0.9.0
+json5==0.12.0
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2025.4.1
+jupyter-client==8.6.3
+jupyter-core==5.7.2
+jupyter-events==0.12.0
+jupyter-lsp==2.2.5
+jupyter-server==2.15.0
+jupyter-server-terminals==0.5.3
+jupyterlab==4.4.2
+jupyterlab-pygments==0.3.0
+jupyterlab-server==2.27.3
+kiwisolver==1.4.8
+markdown-it-py==3.0.0
+markupsafe==3.0.2
+matplotlib==3.10.1
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mistune==3.1.3
+mpmath==1.3.0
+multidict==6.4.3
+multiprocess==0.70.16
+nbclient==0.10.2
+nbconvert==7.16.6
+nbformat==5.10.4
+nest-asyncio==1.6.0
+networkx==3.4.2
+notebook==7.4.2
+notebook-shim==0.2.4
+numpy==2.2.5
+openai==1.77.0
+orjson==3.10.18
+overrides==7.7.0
+packaging==25.0
+pandas==2.2.3
+pandocfilters==1.5.1
+parso==0.8.4
+pexpect==4.9.0
+pillow==11.2.1
+platformdirs==4.3.8
+prometheus-client==0.21.1
+prompt-toolkit==3.0.51
+propcache==0.3.1
+protobuf==6.30.2
+psutil==7.0.0
+ptyprocess==0.7.0
+pure-eval==0.2.3
+pyarrow==20.0.0
+pyasn1==0.6.1
+pyasn1-modules==0.4.2
+pycparser==2.22
+pydantic==2.11.4
+pydantic-core==2.33.2
+pydub==0.25.1
+pygments==2.19.1
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.0
+python-json-logger==3.3.0
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+pyzmq==26.4.0
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.3
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rich==14.0.0
+rpds-py==0.24.0
+rsa==4.9.1
+ruff==0.11.8
+safehttpx==0.1.6
+safetensors==0.5.3
+semantic-version==2.10.0
+send2trash==1.8.3
+sentencepiece==0.2.0
+setuptools==80.3.1
+shellingham==1.5.4
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.7
+stack-data==0.6.3
+starlette==0.46.2
+sympy==1.14.0
+terminado==0.18.1
+tinycss2==1.4.0
+tokenizers==0.21.1
+tomlkit==0.13.2
+torch==2.7.0
+tornado==6.4.2
+tqdm==4.67.1
+traitlets==5.14.3
+transformers==4.51.3
+typer==0.15.3
+types-python-dateutil==2.9.0.20241206
+typing-extensions==4.13.2
+typing-inspection==0.4.0
+tzdata==2025.2
+uri-template==1.3.0
+urllib3==2.4.0
+uvicorn==0.34.2
+wcwidth==0.2.13
+webcolors==24.11.1
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==15.0.1
+xxhash==3.5.0
+yarl==1.20.0
tools/schema_gemini.json
ADDED
@@ -0,0 +1,45 @@
+{
+    "name": "extract_menu_data",
+    "description": "Extract structured menu information from images.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "restaurant": {
+                "type": "string",
+                "description": "Name of the restaurant. If the name is not available, it should be ''."
+            },
+            "address": {
+                "type": "string",
+                "description": "Address of the restaurant. If the address is not available, it should be ''."
+            },
+            "phone": {
+                "type": "string",
+                "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+            },
+            "business_hours": {
+                "type": "string",
+                "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+            },
+            "dishes": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "name": {
+                            "type": "string",
+                            "description": "Name of the menu item."
+                        },
+                        "price": {
+                            "type": "number",
+                            "format": "float",
+                            "description": "Price of the menu item. If the price is not available, it should be -1."
+                        }
+                    },
+                    "required": ["name", "price"]
+                },
+                "description": "List of menu dish items."
+            }
+        },
+        "required": ["restaurant", "address", "phone", "business_hours", "dishes"]
+    }
+}
tools/schema_openai.json
ADDED
@@ -0,0 +1,48 @@
+{
+    "type": "function",
+    "name": "extract_menu_data",
+    "description": "Extract structured menu information from images.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "restaurant": {
+                "type": "string",
+                "description": "Name of the restaurant. If the name is not available, it should be ''."
+            },
+            "address": {
+                "type": "string",
+                "description": "Address of the restaurant. If the address is not available, it should be ''."
+            },
+            "phone": {
+                "type": "string",
+                "description": "Phone number of the restaurant. If the phone number is not available, it should be ''."
+            },
+            "business_hours": {
+                "type": "string",
+                "description": "Business hours of the restaurant. If the business hours are not available, it should be ''."
+            },
+            "dishes": {
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "name": {
+                            "type": "string",
+                            "description": "Name of the menu item."
+                        },
+                        "price": {
+                            "type": "number",
+                            "format": "float",
+                            "description": "Price of the menu item. If the price is not available, it should be -1."
+                        }
+                    },
+                    "required": ["name", "price"],
+                    "additionalProperties": false
+                },
+                "description": "List of menu dish items."
+            }
+        },
+        "required": ["restaurant", "address", "phone", "business_hours", "dishes"],
+        "additionalProperties": false
+    }
+}
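Both schema files describe the same payload shape; only the wrapper differs (Gemini takes a bare function declaration, while the OpenAI Responses API takes a `type: function` tool with `additionalProperties: false`). For illustration, a conforming extraction result as returned by `GeminiAPI.call` / `OpenAIAPI.call` (all values hypothetical):

```python
# Hypothetical extraction result conforming to tools/schema_*.json
result = {
    "restaurant": "Example Cafe",     # "" when the name is not on the menu
    "address": "",                    # "" when not available
    "phone": "02-1234-5678",
    "business_hours": "11:00-21:00",
    "dishes": [
        {"name": "Beef Noodles", "price": 160},
        {"name": "Iced Tea", "price": -1},  # -1 when the price is missing
    ],
}
```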
train.ipynb
ADDED
@@ -0,0 +1,294 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Login to HuggingFace (just login once)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from huggingface_hub import interpreter_login\n",
+    "interpreter_login()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Collect Menu Image Datasets\n",
+    "- Use `metadata.jsonl` to label the images' ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.\n",
+    "- After finishing, push to HuggingFace Datasets.\n",
+    "- For labeling:\n",
+    "  - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).\n",
+    "  - Use function calling via the API. Start the Gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).\n",
+    "\n",
+    "### Menu Type\n",
+    "- **h**: horizontal menu\n",
+    "- **v**: vertical menu\n",
+    "- **d**: document-style menu\n",
+    "- **s**: in-scene menu (non-document style)\n",
+    "- **i**: irregular menu (menu with irregular text layout)\n",
+    "\n",
+    "> Please see the [examples](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) for more details."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "\n",
+    "dataset = load_dataset(path=\"datasets/menu-zh-TW\") # load the dataset from the local directory containing the metadata.jsonl and image files\n",
+    "dataset.push_to_hub(repo_id=\"ryanlinjui/menu-zh-TW\") # push to the huggingface dataset hub"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Setup for Fine-tuning"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig\n",
+    "\n",
+    "from menu.donut import DonutDatasets\n",
+    "\n",
+    "DATASETS_REPO_ID = \"ryanlinjui/menu-zh-TW\" # set your dataset repo id for training\n",
+    "PRETRAINED_MODEL_REPO_ID = \"naver-clova-ix/donut-base\" # set your pretrained model repo id for fine-tuning\n",
+    "TASK_PROMPT_NAME = \"<s_menu>\" # set your task prompt name for training\n",
+    "MAX_LENGTH = 768 # set your max length for maximum output length\n",
+    "IMAGE_SIZE = [1280, 960] # set your image size for training\n",
+    "\n",
+    "raw_datasets = load_dataset(DATASETS_REPO_ID)\n",
+    "\n",
+    "# Config: set the model config\n",
+    "config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_REPO_ID)\n",
+    "config.encoder.image_size = IMAGE_SIZE\n",
+    "config.decoder.max_length = MAX_LENGTH\n",
+    "\n",
+    "# Processor: use the processor to process the dataset.\n",
+    "# Convert the image to the tensor and the text to the token ids.\n",
+    "processor = DonutProcessor.from_pretrained(PRETRAINED_MODEL_REPO_ID)\n",
+    "processor.feature_extractor.size = IMAGE_SIZE[::-1]\n",
+    "processor.feature_extractor.do_align_long_axis = False\n",
+    "\n",
+    "# DonutDatasets: use DonutDatasets to process the dataset.\n",
+    "# For model input, the image must be converted to a tensor and the JSON text must be converted to tokens with the task prompt string.\n",
+    "# This example uses the column names \"image\" and \"menu\", so the image file is in the \"image\" column and the JSON text is in the \"menu\" column.\n",
+    "datasets = DonutDatasets(\n",
+    "    datasets=raw_datasets,\n",
+    "    processor=processor,\n",
+    "    image_column=\"image\",\n",
+    "    annotation_column=\"menu\",\n",
+    "    task_start_token=TASK_PROMPT_NAME,\n",
+    "    prompt_end_token=TASK_PROMPT_NAME,\n",
+    "    train_split=0.8,\n",
+    "    validation_split=0.1,\n",
+    "    test_split=0.1,\n",
+    "    sort_json_key=True,\n",
+    "    seed=42\n",
+    ")\n",
+    "\n",
+    "# Model: load the pretrained model and set the config.\n",
+    "model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_REPO_ID, config=config)\n",
+    "model.decoder.resize_token_embeddings(len(processor.tokenizer))\n",
+    "model.config.pad_token_id = processor.tokenizer.pad_token_id\n",
+    "model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Start Fine-tuning"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
+    "\n",
+    "HUGGINGFACE_MODEL_ID = \"ryanlinjui/donut-base-finetuned-menu\" # set your huggingface model repo id for saving / pushing to the hub\n",
+    "EPOCHS = 100 # set your training epochs\n",
+    "TRAIN_BATCH_SIZE = 4 # set your training batch size\n",
+    "\n",
+    "device = (\n",
+    "    \"cuda\"\n",
+    "    if torch.cuda.is_available()\n",
+    "    else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
+    ")\n",
+    "print(f\"Using {device} device\")\n",
+    "model.to(device)\n",
+    "\n",
+    "training_args = Seq2SeqTrainingArguments(\n",
+    "    num_train_epochs=EPOCHS,\n",
+    "    per_device_train_batch_size=TRAIN_BATCH_SIZE,\n",
+    "    learning_rate=3e-5,\n",
+    "    per_device_eval_batch_size=1,\n",
+    "    output_dir=\"./.checkpoints\",\n",
+    "    seed=2022,\n",
+    "    warmup_steps=30,\n",
+    "    eval_strategy=\"steps\",\n",
+    "    eval_steps=100,\n",
+    "    logging_strategy=\"steps\",\n",
+    "    logging_steps=50,\n",
+    "    save_strategy=\"steps\",\n",
+    "    save_steps=200,\n",
+    "    push_to_hub=True if HUGGINGFACE_MODEL_ID else False,\n",
+    "    hub_model_id=HUGGINGFACE_MODEL_ID,\n",
+    "    hub_strategy=\"every_save\",\n",
+    "    report_to=\"tensorboard\",\n",
+    "    logging_dir=\"./.checkpoints/logs\",\n",
+    ")\n",
+    "trainer = Seq2SeqTrainer(\n",
+    "    model=model,\n",
+    "    args=training_args,\n",
+    "    train_dataset=datasets[\"train\"],\n",
+    "    eval_dataset=datasets[\"test\"],\n",
+    "    tokenizer=processor\n",
+    ")\n",
+    "\n",
+    "trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import (\n",
+    "    VisionEncoderDecoderModel,\n",
+    "    DonutProcessor,\n",
+    "    pipeline\n",
+    ")\n",
+    "from PIL import Image\n",
+    "\n",
+    "model_id = \"ryanlinjui/donut-base-finetuned-menu\"\n",
+    "\n",
+    "# 1. Download and load the model + processor\n",
+    "processor = DonutProcessor.from_pretrained(model_id)\n",
+    "model = VisionEncoderDecoderModel.from_pretrained(model_id)\n",
+    "\n",
+    "# 2. Build an image-to-text pipeline\n",
+    "ocr_pipeline = pipeline(\n",
+    "    \"image-to-text\", # use the image-to-text task\n",
+    "    model=model, # pass in the loaded model\n",
+    "    tokenizer=processor.tokenizer,\n",
+    "    feature_extractor=processor.feature_extractor,\n",
+    ")\n",
+    "\n",
+    "# 3. Load a test image\n",
+    "image = Image.open(\"./examples/menu-hd.jpg\")\n",
+    "\n",
+    "# 4. Call the pipeline and get the result\n",
+    "outputs = ocr_pipeline(image)\n",
+    "\n",
+    "# 5. Print the recognized text\n",
+    "print(outputs[0][\"generated_text\"])\n",
+    "\n",
+    "'''\n",
+    "# test model\n",
+    "import re\n",
+    "\n",
+    "from transformers import VisionEncoderDecoderModel\n",
+    "from transformers import DonutProcessor\n",
+    "import torch\n",
+    "from PIL import Image\n",
+    "\n",
+    "image = Image.open(\"./examples/menu-hd.jpg\").convert(\"RGB\")\n",
+    "\n",
+    "processor = DonutProcessor.from_pretrained(\"ryanlinjui/donut-base-finetuned-menu\")\n",
+    "model = VisionEncoderDecoderModel.from_pretrained(\"ryanlinjui/donut-base-finetuned-menu\")\n",
+    "device = \"cuda\" if torch.cuda.is_available() else \"mps\"\n",
+    "\n",
+    "model.eval()\n",
+    "model.to(device)\n",
+    "\n",
+    "pixel_values = processor(image, return_tensors=\"pt\").pixel_values\n",
+    "pixel_values = pixel_values.to(device)\n",
+    "\n",
+    "task_prompt = \"<s_menu>\"\n",
+    "decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors=\"pt\").input_ids\n",
+    "decoder_input_ids = decoder_input_ids.to(device)\n",
+    "outputs = model.generate(\n",
+    "    pixel_values,\n",
+    "    decoder_input_ids=decoder_input_ids,\n",
+    "    max_length=model.decoder.config.max_position_embeddings,\n",
+    "    early_stopping=True,\n",
+    "    pad_token_id=processor.tokenizer.pad_token_id,\n",
+    "    eos_token_id=processor.tokenizer.eos_token_id,\n",
+    "    use_cache=True,\n",
+    "    num_beams=1,\n",
+    "    bad_words_ids=[[processor.tokenizer.unk_token_id]],\n",
+    "    return_dict_in_generate=True,\n",
+    ")\n",
+    "\n",
+    "seq = processor.batch_decode(outputs.sequences)[0]\n",
+    "seq = seq.replace(processor.tokenizer.eos_token, \"\").replace(processor.tokenizer.pad_token, \"\")\n",
+    "# seq = re.sub(r\"<.*?>\", \"\", seq, count=1).strip() # remove first task start token\n",
+    "seq = processor.token2json(seq)\n",
+    "print(seq)\n",
+    "'''\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Plot the results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Training Loss\n",
+    "# Validation: normalized edit distance per epoch (from 1 down to 0.22)\n",
+    "# Test: TED accuracy 0.687058, F1 score 0.51119"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
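The final notebook cell records validation normalized edit distance falling from 1 to 0.22, but the notebook itself does not ship that helper. A minimal sketch of the metric under the usual reading (character-level Levenshtein distance normalized by the longer string, where 0 means an exact match):

```python
# Minimal sketch of normalized edit distance (not shipped in this repo):
# Levenshtein distance over characters divided by the longer string's length.
def normalized_edit_distance(pred: str, truth: str) -> float:
    m, n = len(pred), len(truth)
    if max(m, n) == 0:
        return 0.0
    prev = list(range(n + 1))  # DP row for the empty-prefix baseline
    for i in range(1, m + 1):
        curr = [i] + [0] * n
        for j in range(1, n + 1):
            cost = 0 if pred[i - 1] == truth[j - 1] else 1
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost)
        prev = curr
    return prev[n] / max(m, n)

print(normalized_edit_distance("<s_menu>tea", "<s_menu>teas"))  # ~0.083
```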
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff