Nitin00043 committed
Commit e50f30c · verified · 1 Parent(s): 4a01533

Update app.py

Files changed (1)
app.py +26 -13
app.py CHANGED
@@ -1,38 +1,51 @@
+import torch
 from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
 import gradio as gr
 from PIL import Image
 
-# Use a valid model identifier.
-# Replace "google/matcha-base" with your checkpoint if you have one.
+# Use a valid model identifier. Here we use "google/matcha-base".
 model_name = "google/matcha-base"
 
 # Load the pre-trained Pix2Struct model and processor
 model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
 processor = Pix2StructProcessor.from_pretrained(model_name)
 
-# Function to solve handwritten math problems
+# Move model to GPU if available for faster inference
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
 def solve_math_problem(image):
-    # Preprocess the image: here we render a prompt asking the model to solve the problem.
+    # Preprocess the image and include a clear prompt.
+    # You can adjust the prompt to better match your task if needed.
     inputs = processor(images=image, text="Solve the math problem:", return_tensors="pt")
+    # Ensure the tensors are on the same device as the model
+    inputs = {key: value.to(device) for key, value in inputs.items()}
 
-    # Generate the solution
-    predictions = model.generate(**inputs, max_new_tokens=100)
+    # Generate the solution using beam search.
+    # Adjust parameters for best performance:
+    # - max_new_tokens: Allows longer responses.
+    # - num_beams: Uses beam search to explore multiple hypotheses.
+    # - early_stopping: Stops decoding once a complete answer is generated.
+    # - temperature: Controls randomness (lower value = more deterministic).
+    predictions = model.generate(
+        **inputs,
+        max_new_tokens=150,
+        num_beams=5,
+        early_stopping=True,
+        temperature=0.5
+    )
 
-    # Decode the output
+    # Decode the output to get a string answer, skipping any special tokens.
     solution = processor.decode(predictions[0], skip_special_tokens=True)
     return solution
 
-# Gradio interface
+# Set up a Gradio interface
 demo = gr.Interface(
     fn=solve_math_problem,
     inputs=gr.Image(type="pil", label="Upload Handwritten Math Problem"),
     outputs=gr.Textbox(label="Solution"),
     title="Handwritten Math Problem Solver",
-    description="Upload an image of a handwritten math problem, and the model will attempt to solve it.",
-    examples=[
-        ["example1.jpg"],  # Add example images if available
-        ["example2.jpg"]
-    ],
+    description="Upload an image of a handwritten math problem and the model will attempt to solve it.",
     theme="soft"
 )