bpdev75 commited on
Commit
86e21aa
·
verified ·
1 Parent(s): 63ed070

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import VisionEncoderDecoderModel, AutoTokenizer
3
+ from texteller.models.ocr_model.utils.inference import inference as latex_inference
4
+ from texteller.models.ocr_model.utils.to_katex import to_katex
5
+ from PIL import Image
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+
10
+ # Configure Streamlit page layout
11
+ st.set_page_config(layout="wide")
12
+ st.title("TeXTeller Demo – LaTeX Code Prediction from Images")
13
+
14
+ # Load the TeXTeller model and tokenizer only once
15
+ @st.cache_resource
16
+ def load_model():
17
+ checkpoint = "OleehyO/TexTeller"
18
+ model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
19
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
20
+ return model, tokenizer
21
+
22
+ model, tokenizer = load_model()
23
+
24
+ # Utility function to render LaTeX as an image
25
+ def latex2image(latex_expression, image_size_in=(3, 0.5), fontsize=16, dpi=200):
26
+ fig = plt.figure(figsize=image_size_in, dpi=dpi)
27
+ fig.text(
28
+ x=0.5,
29
+ y=0.5,
30
+ s=f"${latex_expression}$",
31
+ horizontalalignment="center",
32
+ verticalalignment="center",
33
+ fontsize=fontsize
34
+ )
35
+ buf = io.BytesIO()
36
+ plt.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0.1)
37
+ plt.close(fig)
38
+ buf.seek(0)
39
+ return Image.open(buf)
40
+
41
+ # Upload box for the user to provide an input image
42
+ uploaded_file = st.file_uploader("Upload a math image (JPG, PNG)...", type=["jpg", "jpeg", "png"])
43
+
44
+ # If an image is uploaded, process it
45
+ if uploaded_file:
46
+ # Display three columns: original image, predicted LaTeX, rendered LaTeX
47
+ col1, col2, col3 = st.columns(3)
48
+
49
+ # Load image using PIL
50
+ image = Image.open(uploaded_file)
51
+ with col1:
52
+ st.image(image, caption="Original Image", use_container_width=True)
53
+
54
+ # Perform prediction
55
+ with st.spinner("Running OCR model..."):
56
+ try:
57
+ res = latex_inference(model, tokenizer, [np.array(image)])
58
+ predicted_latex = to_katex(res[0])
59
+
60
+ # Show the predicted LaTeX string
61
+ with col2:
62
+ st.markdown("**Predicted LaTeX code:**")
63
+ st.text_area(label="", value=predicted_latex, height=80)
64
+
65
+ # Convert LaTeX string to an image and display
66
+ with col3:
67
+ pred_image = latex2image(predicted_latex)
68
+ st.image(pred_image, caption="Rendered from Prediction", use_container_width=True)
69
+
70
+ except Exception as e:
71
+ st.error(f"Error during prediction: {e}")