Nitin00043 commited on
Commit
a1c289f
·
verified ·
1 Parent(s): 33f2f42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -1,20 +1,28 @@
1
  # app.py
2
 
3
  import gradio as gr
 
4
  from PIL import Image
5
- import pytesseract
6
  import sympy
7
 
 
 
 
 
8
  def solve_math_problem(image):
9
  try:
10
- # Convert image to grayscale for better OCR performance
11
- image = image.convert("L")
 
 
 
12
 
13
- # Use pytesseract to extract text from the image
14
- problem_text = pytesseract.image_to_string(image, config='--psm 7')
 
15
 
16
  # Clean and prepare the extracted text
17
- problem_text = problem_text.strip().replace('\n', '').replace(' ', '')
18
 
19
  # Use sympy to parse and solve the equation
20
  # Handle simple arithmetic and algebraic equations
@@ -38,7 +46,7 @@ demo = gr.Interface(
38
  inputs=gr.Image(
39
  type="pil",
40
  label="Upload Handwritten Math Problem",
41
- image_mode="L" # Grayscale mode improves OCR accuracy
42
  ),
43
  outputs=gr.Markdown(),
44
  title="Handwritten Math Problem Solver",
@@ -47,8 +55,7 @@ demo = gr.Interface(
47
  ["example_addition.png"],
48
  ["example_algebra.jpg"]
49
  ],
50
- allow_flagging="never",
51
- theme="soft"
52
  )
53
 
54
  if __name__ == "__main__":
 
1
  # app.py
2
 
3
  import gradio as gr
4
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
  from PIL import Image
 
6
  import sympy
7
 
8
+ # Load the pre-trained model and processor outside the function for efficiency
9
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
10
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
11
+
12
  def solve_math_problem(image):
13
  try:
14
+ # Ensure the image is in RGB format
15
+ image = image.convert("RGB")
16
+
17
+ # Resize and normalize the image as expected by the model
18
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
19
 
20
+ # Generate the text (this extracts the handwritten equation)
21
+ generated_ids = model.generate(pixel_values)
22
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
23
 
24
  # Clean and prepare the extracted text
25
+ problem_text = generated_text.strip().replace(' ', '')
26
 
27
  # Use sympy to parse and solve the equation
28
  # Handle simple arithmetic and algebraic equations
 
46
  inputs=gr.Image(
47
  type="pil",
48
  label="Upload Handwritten Math Problem",
49
+ image_mode="RGB"
50
  ),
51
  outputs=gr.Markdown(),
52
  title="Handwritten Math Problem Solver",
 
55
  ["example_addition.png"],
56
  ["example_algebra.jpg"]
57
  ],
58
+ allow_flagging="never"
 
59
  )
60
 
61
  if __name__ == "__main__":