File size: 1,268 Bytes
ec7d971
623c9e7
ec7d971
623c9e7
ec7d971
 
 
 
191e2cd
ec7d971
 
 
 
adc05de
ec7d971
 
adc05de
ec7d971
 
 
623c9e7
ec7d971
1ee9cdc
ec7d971
 
 
 
 
623c9e7
ec7d971
 
623c9e7
ec7d971
623c9e7
 
ec7d971
623c9e7
ec7d971
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Pre-trained Pix2Struct checkpoint shared by the model and its processor.
# NOTE(review): confirm this checkpoint id exists on the Hugging Face Hub.
model_name = "google/pix2struct-mathqa-base"

# Build both objects from the same checkpoint: the network weights first,
# then the matching image/text preprocessor.
model, processor = (
    Pix2StructForConditionalGeneration.from_pretrained(model_name),
    Pix2StructProcessor.from_pretrained(model_name),
)

# Function to solve handwritten math problems
def solve_math_problem(
    image,
    *,
    prompt="Solve the math problem:",
    max_new_tokens=100,
):
    """Run the Pix2Struct model on an uploaded image and return its answer text.

    Args:
        image: The uploaded image (a PIL image when called from the Gradio
            ``gr.Image(type="pil")`` input), or ``None`` if nothing was uploaded.
        prompt: Text prompt passed to the processor alongside the image.
        max_new_tokens: Upper bound on the number of tokens ``generate`` may emit.

    Returns:
        The decoded model output with surrounding whitespace removed, or a short
        instruction message when no image was provided.
    """
    # Gradio passes None when the image input is empty; feeding that to the
    # processor would presumably raise, so reply with a hint instead.
    if image is None:
        return "Please upload an image of a handwritten math problem."

    # Preprocess the image together with the text prompt.
    inputs = processor(images=image, text=prompt, return_tensors="pt")

    # Generate the solution.
    predictions = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode the first sequence of the (single-item) batch.
    return processor.decode(predictions[0], skip_special_tokens=True).strip()

# Example images are optional. Gradio loads example files while building the
# interface, so listing paths that are not on disk can abort startup; only the
# placeholders that actually exist (relative to the working directory) are used.
_EXAMPLE_IMAGES = ("example1.jpg", "example2.jpg")  # Add example images
_examples = [[path] for path in _EXAMPLE_IMAGES if os.path.isfile(path)]

# Gradio interface
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(type="pil", label="Upload Handwritten Math Problem"),
    outputs=gr.Textbox(label="Solution"),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem, and the model will solve it.",
    # None disables the examples panel when no example file is present.
    examples=_examples or None,
    theme="soft",
)

def _launch_app():
    """Start the Gradio server for the handwritten-math demo."""
    demo.launch()


# Only start the server when this file is executed as a script, not on import.
if __name__ == "__main__":
    _launch_app()