# Last commit: 328b005 "Update app.py" by JohnTitor22 (verified)
import io
import torch
import gradio as gr
import numpy as np
from PIL import Image
from unimernet.common.config import Config
from unimernet.processors import load_processor
import unimernet.tasks as tasks
import argparse
import os
# Upper bounds (in pixels) passed to Image.thumbnail() before recognition:
# larger inputs are downscaled to fit, keeping their aspect ratio.
MAX_WIDTH, MAX_HEIGHT = 872, 1024
class ImageProcessor:
    """Load a UniMERNet model from a config file and recognize formulas in images.

    Attributes:
        cfg_path: Path to the UniMERNet YAML config used to build the model.
        device: ``cuda`` when available, otherwise ``cpu``.
        model: The recognition model, moved to ``device`` and put in eval mode.
        vis_processor: Eval-time image transform built from the config's
            ``datasets.formula_rec_eval.vis_processor.eval`` section.
    """

    def __init__(self, cfg_path):
        self.cfg_path = cfg_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the model and its eval image processor from ``self.cfg_path``.

        Returns:
            A ``(model, vis_processor)`` tuple.
        """
        # Config expects an argparse-style namespace, as if parsed from the CLI.
        args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
        cfg = Config(args)
        task = tasks.setup_task(cfg)
        model = task.build_model(cfg).to(self.device)
        # This class only does inference: disable training-time behavior
        # (dropout, batch-norm statistic updates) so predictions are stable.
        model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        return model, vis_processor

    def process_single_image(self, pil_image):
        """Return the predicted LaTeX string for one PIL image."""
        # unsqueeze(0) adds a leading batch dimension of size 1.
        image = self.vis_processor(pil_image).unsqueeze(0).to(self.device)
        # Inference needs no gradients; skipping the autograd graph saves memory.
        with torch.no_grad():
            output = self.model.generate({"image": image})
        # One image in the batch, so the first prediction is the answer.
        return output["pred_str"][0]
# Initialize the model once at import time; the prediction functions share it.
cfg_path = "demo.yaml"
processor = ImageProcessor(cfg_path=cfg_path)
# Single-image prediction
def _to_rgb(img, background=(255, 255, 255)):
    """Return an RGB version of *img*, flattening any transparency onto *background*.

    ``Image.convert("RGB")`` simply discards the alpha channel, so fully
    transparent pixels take whatever color they store (often black) and a dark
    formula on a transparent background turns into a solid dark image.
    Compositing onto an opaque background keeps the strokes visible.
    """
    if img.mode in ("RGBA", "LA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        canvas = Image.new("RGBA", rgba.size, tuple(background) + (255,))
        return Image.alpha_composite(canvas, rgba).convert("RGB")
    return img.convert("RGB")


def predict_single(img):
    """Recognize the formula in one image and return its LaTeX code.

    Args:
        img: A PIL image, or ``None`` when nothing was uploaded.

    Returns:
        The predicted LaTeX string, or ``"No image uploaded."`` for ``None``.
    """
    if img is None:
        return "No image uploaded."
    img = _to_rgb(img)
    # thumbnail() works in place, only ever shrinks, and keeps the aspect ratio.
    img.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
    return processor.process_single_image(img)
# Batch prediction
def predict_batch(img_list):
    """Recognize formulas in several images, one result string per input.

    Every image goes through :func:`predict_single`, so the single and batch
    tabs share exactly one preprocessing/inference pipeline.

    Args:
        img_list: Sequence of PIL images; ``None`` entries are allowed.

    Returns:
        One LaTeX string per input, in order (``"Invalid image"`` for ``None``
        entries), or ``["No images uploaded."]`` if *img_list* is empty/None.
    """
    if not img_list:
        return ["No images uploaded."]
    return [
        "Invalid image" if img is None else predict_single(img)
        for img in img_list
    ]
# UI construction
title = "UniMERNet Formula Recognition"
description = "Upload an image (or multiple images) containing math formulas. The model will return LaTeX code."


def batch_process(files):
    """Recognize every uploaded file and return ``[(file_name, latex), ...]`` rows.

    Args:
        files: Uploads from ``gr.File(file_count="multiple")`` -- path strings
            or objects exposing the path as ``.name``; ``None``/empty when
            nothing was uploaded.

    Returns:
        A list of ``(base_name, prediction)`` tuples, one per uploaded file;
        an empty list when there is nothing to process.
    """
    if not files:
        # Clicking the button with no upload must not crash the handler.
        return []
    imgs = []
    file_names = []
    for file in files:
        path = getattr(file, "name", file)
        try:
            with Image.open(path) as img:
                # copy() forces a full decode so the image outlives the file handle.
                imgs.append(img.copy())
        except OSError:
            # Undecodable file (PIL's UnidentifiedImageError subclasses OSError):
            # keep the row; predict_batch reports None entries as "Invalid image".
            imgs.append(None)
        file_names.append(os.path.basename(path))
    preds = predict_batch(imgs)
    return list(zip(file_names, preds))


with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Tab("Single Image Recognition"):
        with gr.Row():
            input_img = gr.Image(type="pil", label="Upload a single formula image")
            output_text = gr.Textbox(label="Predicted LaTeX code", lines=5)
        btn_single = gr.Button("Recognize Single Image")
        btn_single.click(fn=predict_single, inputs=input_img, outputs=output_text)

    with gr.Tab("Batch Image Recognition"):
        with gr.Row():
            input_imgs = gr.File(
                file_types=["image"],
                file_count="multiple",
                label="Upload multiple images (png/jpg/jpeg/webp)",
            )
            # NOTE: deliberately created without `headers`/`datatype` -- the
            # original author flagged passing them as a problem; confirm
            # against the deployed Gradio version before re-adding.
            batch_outputs = gr.Dataframe()
        btn_batch = gr.Button("Recognize Batch Images")
        btn_batch.click(fn=batch_process, inputs=input_imgs, outputs=batch_outputs)
def _main():
    """Start the Gradio app built earlier in this module."""
    demo.launch()


if __name__ == "__main__":
    _main()