import argparse
import os
import random
import re
import time

import evaluate
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from rapidfuzz.distance import Levenshtein
from tabulate import tabulate
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

import unimernet.tasks as tasks
from unimernet.common.config import Config

# imports modules for registration
from unimernet.datasets.builders import *
from unimernet.models import *
from unimernet.processors import *
from unimernet.tasks import *
from unimernet.processors import load_processor

class MathDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx])
        # Fall back to the raw PIL image when no transform is given,
        # instead of referencing an undefined variable.
        if self.transform:
            image = self.transform(image)
        return image
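
# Note: the transform (UniMERNet's vis_processor set up in main()) is assumed
# to map each PIL image to a fixed-size tensor, so the DataLoader's default
# collate can stack a batch.
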
def load_data(image_path, math_file):
    """
    Load a list of image paths and their corresponding formulas.

    The function skips empty lines and lines without corresponding images.

    Args:
        image_path (str): The path to the directory containing the image files.
        math_file (str): The path to the text file containing the formulas.

    Returns:
        list, list: A list of image paths and a list of corresponding formulas.
    """
    image_names = sorted(os.listdir(image_path))
    image_paths = [os.path.join(image_path, f) for f in image_names]
    image_name_set = set(image_names)  # O(1) membership checks
    math_gts = []
    with open(math_file, 'r') as f:
        # Keep a formula only if its line is non-empty and its zero-indexed
        # image (e.g. 0000000.png for line 1) exists on disk.
        for i, line in enumerate(f, start=1):
            image_name = f'{i - 1:07d}.png'
            if line.strip() and image_name in image_name_set:
                math_gts.append(line.strip())

    if len(image_paths) != len(math_gts):
        raise ValueError("The number of images does not match the number of formulas.")

    return image_paths, math_gts
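
# Example (assuming the UniMER-Test layout used in main(), where line k of the
# .txt file annotates image k-1):
#   image_list, math_gts = load_data("./data/UniMER-Test/spe", "./data/UniMER-Test/spe.txt")
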
def normalize_text(text):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = '[a-zA-Z]'
    noletter = r'[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, text)]
    text = re.sub(text_reg, lambda match: str(names.pop(0)), text)
    news = text
    # Repeatedly collapse whitespace around non-letter tokens until the
    # string stops changing.
    while True:
        text = news
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', text)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == text:
            break
    return text
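
# Example: normalize_text(r'\mathrm {d} x') -> '\mathrm{d}x'; the \mathrm
# group is compacted first, then the space before 'x' is stripped by the
# whitespace-collapsing loop.
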
def score_text(predictions, references):
    # A random experiment_id keeps concurrent evaluations from clashing over
    # the same metric cache. randint needs ints, hence int(1e8).
    bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, int(1e8)))
    bleu_results = bleu.compute(predictions=predictions, references=references)

    lev_dist = []
    for p, r in zip(predictions, references):
        lev_dist.append(Levenshtein.normalized_distance(p, r))

    return {
        'bleu': bleu_results["bleu"],
        'edit': sum(lev_dist) / len(lev_dist)
    }
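
# score_text returns 'bleu' (corpus BLEU in [0, 1], higher is better) and
# 'edit' (mean normalized Levenshtein distance in [0, 1], lower is better),
# matching the ⬆/⬇ direction markers printed in main().
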
def setup_seeds(seed=3):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
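
# Note: cudnn.deterministic trades some speed for reproducibility, and
# torch.manual_seed also seeds the CUDA generators, so a separate
# torch.cuda.manual_seed_all call should not be needed here.
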
def parse_args():
    parser = argparse.ArgumentParser(description="UniMERNet evaluation")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--result_path", type=str, help="Path to json file to save result to.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
        "in xxx=yyy format will be merged into the config file (deprecated), "
        "use --cfg-options instead.",
    )
    args = parser.parse_args()
    return args
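
# Example invocation (script and config file names are illustrative, not
# taken from this repo):
#   python eval.py --cfg-path configs/demo.yaml --options key=value
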
def main():
    setup_seeds()

    # Load model and processor
    start = time.time()
    cfg = Config(parse_args())
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
    model.to(device)
    model.eval()

    print(f'arch_name:{cfg.config.model.arch}')
    print(f'model_type:{cfg.config.model.model_type}')
    print(f'checkpoint:{cfg.config.model.finetuned}')
    print('=' * 100)

    end1 = time.time()
    print(f'Device:{device}')
    print(f'Load model: {end1 - start:.3f}s')

    # Load data (images and corresponding annotations)
    val_names = [
        "Simple Print Expression(SPE)",
        "Complex Print Expression(CPE)",
        "Screen Capture Expression(SCE)",
        "Handwritten Expression(HWE)"
    ]
    image_paths = [
        "./data/UniMER-Test/spe",
        "./data/UniMER-Test/cpe",
        "./data/UniMER-Test/sce",
        "./data/UniMER-Test/hwe"
    ]
    math_files = [
        "./data/UniMER-Test/spe.txt",
        "./data/UniMER-Test/cpe.txt",
        "./data/UniMER-Test/sce.txt",
        "./data/UniMER-Test/hwe.txt"
    ]

    for val_name, image_path, math_file in zip(val_names, image_paths, math_files):
        image_list, math_gts = load_data(image_path, math_file)

        # Generate predictions with the MFR model
        transform = transforms.Compose([
            vis_processor,
        ])
        dataset = MathDataset(image_list, transform=transform)
        dataloader = DataLoader(dataset, batch_size=128, num_workers=32)

        math_preds = []
        for images in tqdm(dataloader):
            images = images.to(device)
            with torch.no_grad():
                output = model.generate({"image": images})
            math_preds.extend(output["pred_str"])

        # Compute BLEU and normalized edit distance
        norm_gts = [normalize_text(gt) for gt in math_gts]
        norm_preds = [normalize_text(pred) for pred in math_preds]
        print(f'len_gts:{len(norm_gts)}, len_preds:{len(norm_preds)}')
        print(f'norm_gts[0]:{norm_gts[0]}')
        print(f'norm_preds[0]:{norm_preds[0]}')

        p_scores = score_text(norm_preds, norm_gts)
        write_data = {
            "scores": p_scores,
            "text": [{"prediction": p, "reference": r} for p, r in zip(norm_preds, norm_gts)]
        }

        score_table = []
        score_headers = ["bleu", "edit"]
        score_dirs = ["⬆", "⬇"]
        score_table.append([write_data["scores"][h] for h in score_headers])
        score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]

        end2 = time.time()
        print(f'Evaluation Set:{val_name}')
        # Note: end1 is set once after model load, so this time accumulates
        # across evaluation sets.
        print(f'Inference Time: {end2 - end1:.3f}s')
        print(tabulate(score_table, headers=[*score_headers]))
        print('=' * 100)


if __name__ == "__main__":
    main()