Nitin00043 commited on
Commit
36ec1b6
·
verified ·
1 Parent(s): e2a5fc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -4,9 +4,9 @@ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  from PIL import Image
5
  from sympy import sympify, solve, Eq, symbols
6
 
7
- # Load the math OCR model and processor
8
- processor = TrOCRProcessor.from_pretrained("nlpai-lab/mathocr-htr-base")
9
- model = VisionEncoderDecoderModel.from_pretrained("nlpai-lab/mathocr-htr-base")
10
 
11
  def predict_math_problem(image):
12
  try:
@@ -16,7 +16,7 @@ def predict_math_problem(image):
16
  generated_ids = model.generate(pixel_values)
17
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
18
 
19
- # Standardize mathematical symbols in the transcription
20
  transcription = (transcription
21
  .replace("×", "*")
22
  .replace("÷", "/")
@@ -26,13 +26,12 @@ def predict_math_problem(image):
26
  .replace("³", "**3")
27
  .replace("½", "1/2")
28
  .replace("¼", "1/4")
29
- .replace("…", "...") # Ellipsis
30
  )
31
 
32
- # Attempt to solve the mathematical problem
33
  solution = None
34
  try:
35
- # Check if the transcription is an equation (contains '=')
36
  if '=' in transcription:
37
  lhs, rhs = transcription.split('=', 1)
38
  equation = Eq(sympify(lhs.strip()), sympify(rhs.strip()))
@@ -42,29 +41,27 @@ def predict_math_problem(image):
42
  solution = solve(equation, variable)
43
  solution = f"{variable} = {solution}"
44
  else:
45
- solution = "No variables found in equation"
46
  else:
47
- # Treat as an arithmetic expression
48
- solution = sympify(transcription)
49
- solution = f"Result: {solution}"
50
  except:
51
- solution = "Invalid or unsolvable expression"
52
 
53
  return transcription, solution
54
 
55
  except Exception as e:
56
- return f"Error: {str(e)}", "Failed to process"
57
 
58
  # Create Gradio interface
59
  demo = gr.Interface(
60
  fn=predict_math_problem,
61
- inputs=gr.Image(type="pil", label="Upload Handwritten Math Problem"),
62
  outputs=[
63
- gr.Textbox(label="Transcribed Text"),
64
- gr.Textbox(label="Solution")
65
  ],
66
  title="Handwritten Math Solver",
67
- description="Upload a handwritten math problem to get its transcription and solution."
68
  )
69
 
70
  if __name__ == "__main__":
 
4
  from PIL import Image
5
  from sympy import sympify, solve, Eq, symbols
6
 
7
+ # Load the math OCR model and processor (publicly available)
8
+ processor = TrOCRProcessor.from_pretrained("lambdalabs/smicr_ocr_exp1")
9
+ model = VisionEncoderDecoderModel.from_pretrained("lambdalabs/smicr_ocr_exp1")
10
 
11
  def predict_math_problem(image):
12
  try:
 
16
  generated_ids = model.generate(pixel_values)
17
  transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
18
 
19
+ # Standardize mathematical symbols
20
  transcription = (transcription
21
  .replace("×", "*")
22
  .replace("÷", "/")
 
26
  .replace("³", "**3")
27
  .replace("½", "1/2")
28
  .replace("¼", "1/4")
29
+ .replace("…", "...")
30
  )
31
 
32
+ # Solve the mathematical expression
33
  solution = None
34
  try:
 
35
  if '=' in transcription:
36
  lhs, rhs = transcription.split('=', 1)
37
  equation = Eq(sympify(lhs.strip()), sympify(rhs.strip()))
 
41
  solution = solve(equation, variable)
42
  solution = f"{variable} = {solution}"
43
  else:
44
+ solution = "Solution: No Variables Found"
45
  else:
46
+ solution = f"Result: {sympify(transcription)}"
 
 
47
  except:
48
+ solution = "Error: Unable to Solve Expression"
49
 
50
  return transcription, solution
51
 
52
  except Exception as e:
53
+ return f"Error: {str(e)}", "Processing Failed"
54
 
55
  # Create Gradio interface
56
  demo = gr.Interface(
57
  fn=predict_math_problem,
58
+ inputs=gr.Image(type="pil", label="Handwritten Math Problem"),
59
  outputs=[
60
+ gr.Textbox(label="Transcribed Math Text"),
61
+ gr.Textbox(label="Math Solution")
62
  ],
63
  title="Handwritten Math Solver",
64
+ description="Upload a handwritten math image to get transcription and solution."
65
  )
66
 
67
  if __name__ == "__main__":