# Hugging Face Space — status: Running
import io | |
import torch | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from unimernet.common.config import Config | |
from unimernet.processors import load_processor | |
import unimernet.tasks as tasks | |
import argparse | |
import os | |
# Bounding box (width, height) passed to Image.thumbnail() before recognition.
MAX_WIDTH, MAX_HEIGHT = 872, 1024
class ImageProcessor:
    """Load the configured model once and run formula recognition on images.

    Attributes:
        cfg_path: Path of the configuration file handed to ``Config``.
        device: CUDA device when one is available, otherwise the CPU.
        model: Model built by the configured task, moved to ``device`` and
            switched to evaluation mode.
        vis_processor: The ``formula_image_eval`` processor applied to each
            input image before it is given to the model.
    """

    def __init__(self, cfg_path):
        """Remember ``cfg_path``, pick the device and load model + processor."""
        self.cfg_path = cfg_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the model and the evaluation image processor from the config.

        Returns:
            A ``(model, vis_processor)`` tuple. The model lives on
            ``self.device`` and is in evaluation mode.
        """
        # Config expects an argparse-style namespace; no option overrides here.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # This object is only used for inference: evaluation mode disables
        # training-only behavior such as dropout, keeping predictions stable.
        model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    def process_single_image(self, pil_image):
        """Return the predicted string for a single image.

        Args:
            pil_image: Image accepted by ``self.vis_processor``.

        Returns:
            The first entry of the model output's ``"pred_str"`` list.
        """
        # unsqueeze(0) adds the leading batch dimension for a batch of one.
        image = self.vis_processor(pil_image).unsqueeze(0).to(self.device)
        # No gradients are needed for inference; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            output = self.model.generate({"image": image})
        return output["pred_str"][0]
# Initialize the model: one shared ImageProcessor, created at import time and
# reused by every prediction function below.
cfg_path = "demo.yaml"
processor = ImageProcessor(cfg_path=cfg_path)
# Single-image prediction
def predict_single(img):
    """Recognize the formula in one uploaded image.

    Args:
        img: PIL image coming from the UI, or ``None`` when nothing was
            uploaded.

    Returns:
        The predicted LaTeX code, or ``"No image uploaded."`` when ``img`` is
        ``None``.
    """
    if img is None:
        return "No image uploaded."

    rgb = img.convert("RGB")
    # thumbnail() resizes the image in place to fit the bounding box.
    rgb.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
    return processor.process_single_image(rgb)
# Batch prediction
def predict_batch(img_list):
    """Recognize the formula in each image of a list, preserving order.

    Args:
        img_list: Sequence of PIL images; individual entries may be ``None``.

    Returns:
        One string per input entry: the predicted LaTeX code, or
        ``"Invalid image"`` for a ``None`` entry. When ``img_list`` is empty
        or ``None`` the result is ``["No images uploaded."]``.
    """
    if not img_list:
        return ["No images uploaded."]

    def _recognize(entry):
        # Same preprocessing as the single-image path: RGB + bounded size.
        if entry is None:
            return "Invalid image"
        rgb = entry.convert("RGB")
        rgb.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
        return processor.process_single_image(rgb)

    return [_recognize(entry) for entry in img_list]
# UI construction
title = "UniMERNet Formula Recognition"
description = "Upload an image (or multiple images) containing math formulas. The model will return LaTeX code."


def batch_process(files):
    """Recognize every uploaded file and pair each file name with its result.

    Args:
        files: Value of the multi-file upload component: a list of uploaded
            file objects (or plain paths), or ``None``/empty when nothing has
            been uploaded.

    Returns:
        A list of ``[file name, prediction]`` rows for the results table;
        empty when no files were uploaded.
    """
    if not files:
        # Clicking the button before uploading yields no file list; iterating
        # ``None`` would raise TypeError.
        return []

    imgs = []
    file_names = []
    for file in files:
        # Uploaded file objects expose their path as ``.name``; fall back to
        # the value itself when a plain path is passed.
        path = getattr(file, "name", file)
        file_names.append(os.path.basename(path))
        try:
            # copy() loads the pixel data so the image outlives the file handle.
            with Image.open(path) as img:
                imgs.append(img.copy())
        except OSError:
            # Unreadable upload: pass None so predict_batch reports it as
            # "Invalid image" instead of failing the whole batch.
            imgs.append(None)
    preds = predict_batch(imgs)
    return [[name, pred] for name, pred in zip(file_names, preds)]


# NOTE(review): the original indentation was lost; the widget nesting below is
# a reconstruction — confirm the intended layout.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Tab("Single Image Recognition"):
        with gr.Row():
            input_img = gr.Image(type="pil", label="Upload a single formula image")
            output_text = gr.Textbox(label="Predicted LaTeX code", lines=5)
        btn_single = gr.Button("Recognize Single Image")
        btn_single.click(fn=predict_single, inputs=input_img, outputs=output_text)

    with gr.Tab("Batch Image Recognition"):
        with gr.Row():
            input_imgs = gr.File(
                file_types=["image"],
                file_count="multiple",
                label="Upload multiple images (png/jpg/jpeg/webp)",
            )
            # Deliberately created without headers/datatype (original author's
            # note: "changed here, do not add headers and datatype").
            batch_outputs = gr.Dataframe()
        btn_batch = gr.Button("Recognize Batch Images")
        btn_batch.click(fn=batch_process, inputs=input_imgs, outputs=batch_outputs)

if __name__ == "__main__":
    demo.launch()