# Hugging Face Spaces page header captured with this file (status: Running).
import streamlit as st | |
from transformers import VisionEncoderDecoderModel, AutoTokenizer | |
from datasets import load_dataset, concatenate_datasets | |
from texteller.api.load import load_model, load_tokenizer | |
from texteller.api.inference import img2latex | |
from skimage.metrics import structural_similarity as ssim | |
from modules.cdm.evaluation import compute_cdm_score | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import io | |
from io import BytesIO | |
import base64 | |
import pandas as pd | |
import re | |
import os | |
import evaluate | |
import time | |
from collections import defaultdict | |
import shutil | |
# Configure Streamlit layout: request the "wide" page layout, then render the
# page title shown at the top of the app.
st.set_page_config(layout="wide")
st.title("TeXTeller Demo: LaTeX Code Prediction from Math Images")
# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    """Load the TeXTeller model and its tokenizer.

    Streamlit re-executes the script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``: the objects are built on the
    first call and the same instances are returned on later reruns instead of
    reloading the checkpoint each time.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by ``load_model`` and
        ``load_tokenizer`` for the ``"OleehyO/TexTeller"`` checkpoint.
    """
    checkpoint = "OleehyO/TexTeller"
    model = load_model(checkpoint)
    tokenizer = load_tokenizer(checkpoint)
    return model, tokenizer
@st.cache_resource
def load_data():
    """Load the LaTeX_OCR "small" dataset and annotate every sample.

    All splits are concatenated into a single dataset, then each sample gets
    three extra columns and a normalized ``text`` column:

    * ``complexity``   -- label from ``estimate_complexity``
    * ``latex_length`` -- character count of the LaTeX string
    * ``latex_depth``  -- value of ``max_brace_depth``
    * ``text``         -- the LaTeX string passed through ``normalize_latex``

    The three statistics are computed on the raw (un-normalized) string.

    The result is cached with ``st.cache_resource`` so the download and the
    ``map`` pass are not repeated on every Streamlit rerun; the same dataset
    object is handed back each time.

    Returns:
        The concatenated, annotated dataset.
    """
    splits = load_dataset("linxy/LaTeX_OCR", "small")
    dataset = concatenate_datasets(list(splits.values()))

    def annotate(sample):
        # Read the raw text once; statistics use it before normalization.
        raw_latex = sample["text"]
        return {
            "complexity": estimate_complexity(raw_latex),
            "latex_length": len(raw_latex),
            "latex_depth": max_brace_depth(raw_latex),
            "text": normalize_latex(raw_latex),
        }

    return dataset.map(annotate)
def load_metrics():
    """Return the ``"bleu"`` metric loaded through ``evaluate.load``."""
    bleu = evaluate.load("bleu")
    return bleu
# Utilities to evaluate LaTeX complexity
def count_occurrences(pattern, text):
    """Return the number of non-overlapping matches of ``pattern`` in ``text``.

    Args:
        pattern: Regular expression (string or compiled pattern).
        text: String to search.

    Returns:
        int: How many matches were found.
    """
    return sum(1 for _ in re.finditer(pattern, text))
def max_brace_depth(latex):
    r"""Return the deepest nesting level of ``{...}`` groups in ``latex``.

    Only grouping braces are counted. A character preceded by a backslash is
    skipped, so the literal braces ``\{`` and ``\}`` do not open or close a
    level, while ``\\{`` (an escaped backslash followed by a real brace) does.
    A stray ``}`` with no open group is ignored rather than pushing the depth
    below zero, so it cannot hide nesting that follows it.

    Args:
        latex: LaTeX source string.

    Returns:
        int: Maximum group depth; ``0`` when there are no grouping braces.
    """
    depth = max_depth = 0
    escaped = False
    for char in latex:
        if escaped:
            # This character belongs to a backslash escape such as "\{".
            escaped = False
            continue
        if char == '\\':
            escaped = True
        elif char == '{':
            depth += 1
            max_depth = max(max_depth, depth)
        elif char == '}':
            # Clamp at zero so unbalanced closers do not skew later groups.
            depth = max(depth - 1, 0)
    return max_depth
def estimate_complexity(latex):
    r"""Classify a LaTeX string into one of five complexity labels.

    The score is the sum of:

    * 1   per ``\frac`` / ``\sqrt``
    * 2   per ``\sum`` / ``\prod`` / ``\int``
    * 2   per ``\left`` / ``\right`` / ``\begin`` / ``\end``
    * 3   per ``\begin{bmatrix}`` / ``\begin{matrix}`` / ``\begin{pmatrix}``
      (on top of the ``\begin`` weight above)
    * 0.5 per Greek-letter command
    * the value of ``max_brace_depth(latex)``
    * ``len(latex) / 20``

    Command names are matched as whole control words: a name followed by
    another letter is a different command and is not counted.

    Args:
        latex: LaTeX source string.

    Returns:
        str: ``"very simple"`` (score < 4), ``"simple"`` (< 8),
        ``"medium"`` (< 12), ``"complex"`` (< 20) or ``"very complex"``.
    """
    # A TeX control word runs over all following letters, so without this
    # lookahead "\rightarrow" would count as "\right" and "\intercal" as "\int".
    word_end = r'(?![A-Za-z])'
    greek_letters = (
        r'\\(alpha|beta|gamma|delta|epsilon|zeta|eta|theta|iota|kappa|lambda|mu|nu|xi|pi|rho|sigma|tau|upsilon|phi|chi|psi|omega|'
        r'Gamma|Delta|Theta|Lambda|Xi|Pi|Sigma|Upsilon|Phi|Psi|Omega)'
    )
    weighted_patterns = (
        (r'\\(frac|sqrt)' + word_end, 1),
        (r'\\(sum|prod|int)' + word_end, 2),
        (r'\\(left|right|begin|end)' + word_end, 2),
        (r'\\begin\{(bmatrix|matrix|pmatrix)\}', 3),
        (greek_letters + word_end, 0.5),
    )

    score = sum(
        count_occurrences(pattern, latex) * weight
        for pattern, weight in weighted_patterns
    )
    score += max_brace_depth(latex)
    score += len(latex) / 20

    # Upper bounds are exclusive and checked in ascending order.
    thresholds = (
        (4, "very simple"),
        (8, "simple"),
        (12, "medium"),
        (20, "complex"),
    )
    for upper_bound, label in thresholds:
        if score < upper_bound:
            return label
    return "very complex"
def normalize_latex(latex_code):
    r"""Strip formatting-only parts from a LaTeX string.

    In order: every space character is removed, then every ``\displaystyle``,
    then every ``\begin{align}`` / ``\begin{align*}`` tag, then every
    ``\end{align}`` / ``\end{align*}`` tag.

    Args:
        latex_code: LaTeX source string.

    Returns:
        str: The stripped string.
    """
    stripped = latex_code.replace(" ", "")
    stripped = stripped.replace("\\displaystyle", "")
    # Opening tags are removed in a first pass, closing tags in a second one.
    for env_tag in (r"\\begin\{align\**\}", r"\\end\{align\**\}"):
        stripped = re.sub(env_tag, "", stripped)
    return stripped
def compute_ssim(image1, image2):
    """Compute the SSIM between two PIL images.

    Both images are converted to grayscale (mode ``"L"``) and turned into
    NumPy arrays before being passed to ``ssim``.

    Args:
        image1: First PIL image.
        image2: Second PIL image.

    Returns:
        The value returned by ``ssim`` for the two grayscale arrays.
    """
    grayscale_arrays = [np.array(img.convert("L")) for img in (image1, image2)]
    return ssim(*grayscale_arrays)
# Convert LaTeX to image
def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
    """Render a LaTeX expression to a PIL image with matplotlib.

    The expression is wrapped in ``$...$``, drawn centered on an otherwise
    empty figure and saved as a tightly cropped PNG into an in-memory buffer.

    The figure is closed in a ``finally`` clause, so it is released even when
    drawing or saving raises; the exception itself propagates to the caller.

    Args:
        latex_expression: LaTeX source, without surrounding ``$``.
        image_size_in: Figure size ``(width, height)`` passed as ``figsize``.
        fontsize: Font size of the rendered text.
        dpi: Resolution of the figure.

    Returns:
        PIL.Image.Image: Image opened from the in-memory PNG buffer.
    """
    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=f"${latex_expression}$",
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        buf = io.BytesIO()
        # Save this figure explicitly instead of whichever one is "current".
        fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# --- Convert PIL image to base64 ---
def image_to_base64(pil_img: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its PNG bytes.

    A copy of the image is saved as PNG into an in-memory buffer and the
    buffer content is base64-encoded.

    Args:
        pil_img: Image to encode.

    Returns:
        str: Base64 text of the PNG data.
    """
    snapshot = pil_img.copy()
    png_buffer = BytesIO()
    with png_buffer:
        snapshot.save(png_buffer, 'png')
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
# --- Formatter for HTML rendering ---
def image_formatter(pil_img: Image.Image) -> str:
    """Return an HTML ``<img>`` tag embedding the image as a PNG data URI.

    Args:
        pil_img: Image to embed.

    Returns:
        str: ``<img>`` tag whose ``src`` is the base64 data URI produced from
        ``image_to_base64(pil_img)``.
    """
    encoded = image_to_base64(pil_img)
    return '<img src="data:image/png;base64,' + encoded + '">'
# --- Build HTML table from dictionary ---
def build_html_table(metrics_dico):
    """Turn a metrics dictionary into an HTML table.

    The dictionary is loaded into a DataFrame and exported with HTML escaping
    disabled; the ``"CDM Image"`` column is formatted with ``image_formatter``.

    Args:
        metrics_dico: Mapping accepted by ``pd.DataFrame``.

    Returns:
        str: HTML markup of the table.
    """
    column_formatters = {"CDM Image": image_formatter}
    return pd.DataFrame(metrics_dico).to_html(
        escape=False,
        formatters=column_formatters,
    )
# Instantiate the objects used by the sections below: the model/tokenizer
# pair, the annotated dataset and the BLEU metric.
model, tokenizer = load_model_and_tokenizer()
dataset = load_data()
bleu_metric = load_metrics()
# Section 1: Dataset Overview
# Shows a short introduction followed by a two-column preview (image next to
# its LaTeX code) of the first samples of the dataset.
st.markdown("---")
# NOTE(review): the leading "π" looks like the remains of a mis-decoded emoji;
# the original character cannot be recovered from this file -- please restore.
st.markdown("## π Dataset Overview")
st.markdown("""
This demo uses the [LaTeX_OCR dataset](https://huggingface.co/datasets/linxy/LaTeX_OCR) from Hugging Face 🤗.
Below are 10 examples showing input images and their corresponding LaTeX code.
""")

# Take 10 examples (fewer when the dataset itself is shorter, so that
# ``select`` is never asked for an out-of-range index).
sample_dataset = dataset.select(range(min(10, len(dataset))))

# Constrain the width of the "table" to ~50% using centered columns
col_left, col_center, col_right = st.columns([1, 2, 1])
with col_center:
    header1, header2 = st.columns(2, border=True)
    with header1:
        st.markdown("<p style='text-align: center; font-size: 24px; font-weight: bold;'>Image</p>", unsafe_allow_html=True)
    with header2:
        st.markdown("<p style='text-align: center; font-size: 24px; font-weight: bold;'>LaTeX Code</p>", unsafe_allow_html=True)

    # One bordered row per sample.
    for sample in sample_dataset:
        col1, col2 = st.columns(2, border=True)
        with col1:
            st.image(sample["image"])
        with col2:
            st.markdown(f"`{sample['text']}`")
# ---- Section 2: Exploratory Data Analysis ----
# Draws three side-by-side plots from the per-sample columns added in
# ``load_data``: complexity label counts, LaTeX length and brace depth.
# Each figure is closed right after it is handed to Streamlit; otherwise
# every script rerun would leave three more figures open in pyplot.
st.markdown("---")
st.header("π Exploratory Data Analysis")
st.markdown("We analyze the distribution of LaTeX expressions in terms of complexity, length, and depth.")
df = pd.DataFrame(dataset)
sns.set_theme()

# Layout: 3 plots in a row
col1, col2, col3 = st.columns(3)

with col1:
    fig, ax = plt.subplots(figsize=(3, 3))
    plot = sns.countplot(data=df, x="complexity", order=["very simple", "simple", "medium", "complex", "very complex"], palette="flare", ax=ax)
    plot.set_xticklabels(plot.get_xticklabels(), rotation=45, horizontalalignment='right', fontsize=8)
    ax.set_title("LaTeX Formula Complexity", fontsize=8)
    ax.set_xlabel("")
    ax.set_ylabel("Count", fontsize=8)
    st.pyplot(fig)
    plt.close(fig)

with col2:
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.histplot(df["latex_length"], bins=20, kde=True, ax=ax)
    ax.set_title("Length of LaTeX Code", fontsize=8)
    ax.set_xlabel("Characters", fontsize=8)
    ax.set_ylabel("Count", fontsize=8)
    st.pyplot(fig)
    plt.close(fig)

with col3:
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.histplot(df["latex_depth"], bins=5, kde=True, color="forestgreen", ax=ax)
    ax.set_title("Max Brace Depth of LaTeX Code", fontsize=8)
    ax.set_xlabel("Depth", fontsize=8)
    ax.set_ylabel("Count", fontsize=8)
    st.pyplot(fig)
    plt.close(fig)
# ---- Section 3: Prediction ----
# Lets the user pick the input image: either an uploaded file or one of the
# first 10 dataset samples. Leaves ``image`` (or None) and ``selected_index``
# (set only when a dataset sample was chosen) for the code that follows.
st.markdown("---")
st.header("π TeXTeller Inference")
st.markdown("Upload a math image below to predict the LaTeX code using the TeXTeller model.")

# Radio button to select input source
input_option = st.radio(
    "Choose an input method:",
    options=["Upload your own image", "Use a sample from the dataset"],
    horizontal=True
)

image = None
selected_index = None

if input_option == "Upload your own image":
    uploaded_file = st.file_uploader("Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
elif input_option == "Use a sample from the dataset":
    nb_cols = 5
    # Show 10 images, laid out as rows of ``nb_cols`` bordered columns.
    for row_start in range(0, 10, nb_cols):
        row_cols = st.columns(nb_cols, border=True)
        for offset, col in enumerate(row_cols):
            sample_idx = row_start + offset
            with col:
                # The button sits above its image; clicking it selects the sample.
                if st.button("Select this sample", key=f"btn_{sample_idx}"):
                    selected_index = sample_idx
                st.image(dataset[sample_idx]["image"], use_container_width=True)
    if selected_index is not None:
        image = dataset[selected_index]["image"]
# Once we have a valid image
# Runs TeXTeller on the image, shows the predicted LaTeX and its rendering
# and, when the image came from the dataset (so a reference LaTeX string is
# available), also computes the BLEU, SSIM and CDM comparisons.
if image is not None:
    st.divider()
    st.markdown("### TeXTeller Prediction Output")
    col1, col2, col3 = st.columns(3, border=True)
    with col1:
        st.image(image, caption="Input Image", use_container_width=True)

    with st.spinner("Running TeXTeller..."):
        try:
            dico_result = defaultdict(list)

            # perf_counter is a monotonic clock meant for measuring durations.
            start = time.perf_counter()
            predicted_latex = img2latex(model, tokenizer, [np.array(image)], out_format="katex")[0]
            eval_time = time.perf_counter() - start
            dico_result["Inference Time (s)"].append(f"{eval_time:.2f}")

            with col2:
                st.markdown("**Predicted LaTeX Code:**")
                # The heading above already labels the box, so the widget gets
                # a real (non-empty) label that is simply hidden.
                st.text_area(
                    label="Predicted LaTeX Code",
                    value=predicted_latex,
                    height=80,
                    label_visibility="collapsed",
                )
            with col3:
                rendered_image = latex2image(predicted_latex)
                st.image(rendered_image, caption="Rendered from Prediction", use_container_width=True)

            if selected_index is not None:
                ref_latex = dataset[selected_index]["text"]
                # The reference text was normalized when the dataset was
                # loaded; apply the same normalization to the prediction.
                predicted_latex = normalize_latex(predicted_latex)

                # Compute BLEU score
                bleu_results = bleu_metric.compute(predictions=[predicted_latex], references=[[ref_latex]])
                bleu_score = bleu_results['bleu']
                dico_result["BLEU Score"].append(bleu_score)

                # Compute SSIM (the rendering is resized to the input size first)
                pred_image = rendered_image.resize(image.size)
                ssim_score = compute_ssim(image, pred_image)
                dico_result["SSIM Score"].append(ssim_score)

                # Compute CDM (recall and precision are returned but not shown)
                cdm_score, _cdm_recall, _cdm_precision, compare_img = compute_cdm_score(ref_latex, predicted_latex)
                dico_result["CDM Image"].append(compare_img)
                dico_result["CDM Score"].append(cdm_score)

            # Display metrics (for uploads this is only the inference time)
            html = build_html_table(dico_result)
            st.markdown("### TeXTeller Metrics")
            # CSS forcing the table to span the full width
            st.markdown("""
<style>
table {
width: 100% !important;
}
th, td {
text-align: center !important;
vertical-align: middle !important;
}
</style>
""", unsafe_allow_html=True)
            st.markdown(html, unsafe_allow_html=True)
        except Exception as e:
            # UI boundary: report any failure in the page instead of crashing
            # the script run.
            st.error(f"Error during prediction: {e}")