File size: 4,404 Bytes
328b005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import io
import torch
import gradio as gr
import numpy as np
from PIL import Image
from unimernet.common.config import Config
from unimernet.processors import load_processor
import unimernet.tasks as tasks
import argparse
import os

# Upper bound (in pixels) that uploaded images are shrunk to fit before
# recognition: width x height.
MAX_WIDTH, MAX_HEIGHT = 872, 1024

class ImageProcessor:
    """Load a UniMERNet model from a config file and run single-image inference.

    Attributes:
        cfg_path: Path to the YAML config handed to ``Config``.
        device: ``torch.device`` that the model and input tensors live on.
        model: Model built by the configured task, moved to ``device`` and
            switched to eval mode.
        vis_processor: The ``"formula_image_eval"`` processor that turns a
            PIL image into the model's input tensor.
    """

    def __init__(self, cfg_path, device=None):
        """Build the model and its evaluation image processor.

        Args:
            cfg_path: Path to the YAML config file.
            device: Optional ``torch.device`` or device string (e.g. ``"cpu"``,
                ``"cuda:1"``). Defaults to CUDA when available, else CPU.
        """
        self.cfg_path = cfg_path
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Return ``(model, vis_processor)`` built from ``self.cfg_path``."""
        # Config expects an argparse-style namespace, as if parsed from a CLI.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # This object is only used for inference: make sure train-only layers
        # (dropout, etc.) are disabled regardless of the mode build_model left.
        model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    @torch.no_grad()
    def process_single_image(self, pil_image):
        """Return the model's predicted string for one PIL image.

        Runs under ``torch.no_grad()`` so no autograd graph is recorded.
        """
        # Add a leading batch dimension of 1 before moving to the model's device.
        image = self.vis_processor(pil_image).unsqueeze(0).to(self.device)
        output = self.model.generate({"image": image})
        # Batch of one: take the first (only) prediction.
        return output["pred_str"][0]

# Initialize the model once at import time; the prediction functions share it.
cfg_path = "demo.yaml"
processor = ImageProcessor(cfg_path=cfg_path)

# Single-image prediction
def predict_single(img):
    """Return the LaTeX predicted for one uploaded PIL image.

    ``img`` may be ``None`` (nothing uploaded), in which case a notice string
    is returned instead. Otherwise the image is converted to RGB and passed
    through ``Image.thumbnail`` with bounds (MAX_WIDTH, MAX_HEIGHT) before
    inference.
    """
    if img is not None:
        rgb_image = img.convert("RGB")
        rgb_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
        return processor.process_single_image(rgb_image)
    return "No image uploaded."

# Batch prediction
def predict_batch(img_list):
    """Recognize formulas in several images, one result per input.

    Args:
        img_list: Sequence of PIL images. ``None`` entries are reported as
            ``"Invalid image"`` rather than being processed.

    Returns:
        list[str]: One entry per input, in order: the predicted LaTeX, or
        ``"Invalid image"`` for a ``None`` entry. An empty or ``None``
        ``img_list`` yields ``["No images uploaded."]``.
    """
    if not img_list:
        return ["No images uploaded."]
    # Delegate to predict_single so single and batch modes share exactly the
    # same preprocessing (RGB conversion + size cap) and inference path.
    return [
        "Invalid image" if img is None else predict_single(img)
        for img in img_list
    ]

# Build the interface
title = "UniMERNet Formula Recognition"
description = "Upload an image (or multiple images) containing math formulas. The model will return LaTeX code."

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Tab("Single Image Recognition"):
        with gr.Row():
            input_img = gr.Image(type="pil", label="Upload a single formula image")
            output_text = gr.Textbox(label="Predicted LaTeX code", lines=5)

        btn_single = gr.Button("Recognize Single Image")
        btn_single.click(fn=predict_single, inputs=input_img, outputs=output_text)

    with gr.Tab("Batch Image Recognition"):
        with gr.Row():
            input_imgs = gr.File(file_types=["image"], file_count="multiple",
                                 label="Upload multiple images (png/jpg/jpeg/webp)")

        # NOTE: headers/datatype are deliberately NOT set on this Dataframe
        # (the original author found it only worked without them).
        batch_outputs = gr.Dataframe()

        def batch_process(files):
            """Recognize every uploaded file and pair each result with its file name.

            Args:
                files: Value of the multi-file ``gr.File`` input: ``None`` when
                    nothing was uploaded, otherwise a list of uploads. Depending
                    on the Gradio version these are presumably either path
                    strings or tempfile-like objects exposing ``.name``; both
                    are accepted.

            Returns:
                List of ``(file_name, prediction)`` rows for the Dataframe.
            """
            if not files:
                # Clicking with no uploads passes None; report it instead of
                # crashing on iteration.
                return [("", "No images uploaded.")]
            imgs = []
            file_names = []
            for file in files:
                path = getattr(file, "name", file)
                file_names.append(os.path.basename(path))
                try:
                    with Image.open(path) as img:
                        # Copy so the pixels outlive the closed file handle.
                        imgs.append(img.copy())
                except OSError:
                    # Unreadable / non-image file (PIL's UnidentifiedImageError
                    # subclasses OSError): predict_batch reports None entries as
                    # "Invalid image" instead of aborting the whole batch.
                    imgs.append(None)
            preds = predict_batch(imgs)
            return list(zip(file_names, preds))

        btn_batch = gr.Button("Recognize Batch Images")
        btn_batch.click(fn=batch_process, inputs=input_imgs, outputs=batch_outputs)

def _launch_demo():
    """Start the Gradio app ``demo`` with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _launch_demo()