# NerdAI / app.py
# Source: Hugging Face Space by cuneytkaya, commit 91761d3 ("Update app.py"), 32.1 kB.
import os
import json
import gradio as gr
import gradio.themes as gr_themes # Import themes
import google.generativeai as genai
from PIL import Image
import numpy as np
from huggingface_hub import HfFolder
from dotenv import load_dotenv
import traceback
import pytesseract
import cv2
import time
# --- Load Environment Variables ---
# The Gemini key is read from the process environment. load_dotenv() first
# merges a local .env file (if one exists) into that environment, so the same
# lookup serves local development and deployments that inject the key as an
# environment variable.
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    # Legacy fallback, kept for backward compatibility: this only yields a key
    # when a project-local ``secrets.py`` defining GEMINI_API_KEY shadows the
    # standard-library ``secrets`` module. With the stdlib module the attribute
    # is simply absent and we fall through to the error below.
    try:
        import secrets
        GEMINI_API_KEY = getattr(secrets, "GEMINI_API_KEY", None)
    except ImportError:
        GEMINI_API_KEY = None
if not GEMINI_API_KEY:
    # Fail fast with an actionable message instead of letting the first API
    # call fail later with an authentication error.
    raise ValueError("Gemini API key not found. Please set the GEMINI_API_KEY environment variable or add it as a Secret if running on Hugging Face Spaces.")
genai.configure(api_key=GEMINI_API_KEY)
# --- Model Names ---
# One Gemini model id per pipeline stage; the functions below read these
# module-level constants when they build their GenerativeModel instances.
CLASSIFICATION_MODEL = "gemini-1.5-flash"
SOLUTION_MODEL = "gemini-1.5-pro-latest"
EXPLANATION_MODEL = "gemini-1.5-pro-latest"
SIMILAR_MODEL = "gemini-1.5-pro-latest"
MODEL_IMAGE = "gemini-1.5-pro-latest"  # Using Pro for OCR

# Log the active configuration once at import time.
print(
    "Using models: "
    f"Classification: {CLASSIFICATION_MODEL}, "
    f"Solution: {SOLUTION_MODEL}, "
    f"Explanation: {EXPLANATION_MODEL}, "
    f"Similar: {SIMILAR_MODEL}, "
    f"Image Analysis: {MODEL_IMAGE}"
)
# --- Tesseract Path ---
# Point pytesseract at a Tesseract binary for the fallback OCR path. Failure
# here is never fatal: it is only reported, because Gemini is tried first.
try:
    from shutil import which

    # Probe the well-known install locations in order (system package, then
    # macOS Homebrew); only when neither exists is PATH searched.
    tesseract_path = next(
        (
            candidate
            for candidate in ('/usr/bin/tesseract', '/opt/homebrew/bin/tesseract')
            if os.path.exists(candidate)
        ),
        None,
    ) or which('tesseract')

    if tesseract_path:
        pytesseract.pytesseract.tesseract_cmd = tesseract_path
    else:
        print("Warning: Tesseract command not found at specified paths or in PATH. Fallback OCR might fail.")
except Exception as e:
    print(f"Warning: Error setting Tesseract path: {e}. Fallback OCR might fail.")
# --- Backend Functions ---
def extract_text_with_gemini(image):
    """Read the text of a math problem from an image.

    Gemini (``MODEL_IMAGE``) is tried first. If it raises, returns fewer than
    15 characters, or its reply contains "unable to extract", Tesseract is
    tried as a fallback.

    Args:
        image: A PIL image or a NumPy array; arrays are converted to PIL
            before being sent to Gemini.

    Returns:
        The extracted text, or a string starting with "Error:" when no usable
        text could be obtained. The Tesseract result is used only when it is
        longer than both the Gemini result and 20 characters.
    """
    gemini_text = ""
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        print("Attempting text extraction with Gemini Pro Vision...")
        vision_model = genai.GenerativeModel(MODEL_IMAGE)
        ocr_prompt = (
            "Extract ALL text, numbers, and mathematical equations from this image precisely.\n"
            "Include ALL symbols, numbers, letters, and mathematical notation exactly as they appear.\n"
            "Format any equations properly and maintain their layout as much as possible.\n"
            "Do not add any commentary or explanation, just output the extracted text verbatim."
        )
        reply = vision_model.generate_content([ocr_prompt, image], request_options={'timeout': 120})
        gemini_text = reply.text.strip()
        print(f"Gemini extracted text (first 100 chars): {gemini_text[:100]}...")

        # A near-empty or self-declared failed reply is routed into the same
        # fallback path as a hard error by raising on purpose.
        too_short = len(gemini_text) < 15
        if too_short or "unable to extract" in gemini_text.lower():
            print("Gemini returned limited or no text, trying Tesseract as fallback...")
            raise ValueError("Gemini extraction insufficient, attempting fallback.")
        return gemini_text
    except Exception as gemini_err:
        print(f"Gemini Extraction Error: {gemini_err}. Attempting Tesseract fallback.")
        print(traceback.format_exc())
        try:
            tesseract_ready = (
                'pytesseract' in globals()
                and hasattr(pytesseract.pytesseract, 'tesseract_cmd')
                and pytesseract.pytesseract.tesseract_cmd
            )
            if not tesseract_ready:
                print("Tesseract is not configured. Skipping fallback.")
                if gemini_text:
                    return gemini_text
                return f"Error: Gemini failed and Tesseract is not available. Details: {str(gemini_err)}"

            # Tesseract is fed a single-channel (grayscale) array.
            if isinstance(image, Image.Image):
                gray = np.array(image.convert('L'))
            elif isinstance(image, np.ndarray):
                # A 2-D array is assumed to be grayscale already.
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image
            else:
                return f"Error: Unsupported image type for Tesseract fallback. Gemini Error: {str(gemini_err)}"

            # Optional preprocessing that can help Tesseract (left disabled):
            # gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
            # --psm 6: treat the image as one block of text; 'equ' is only
            # effective if that language pack is installed.
            ocr_text = pytesseract.image_to_string(gray, config=r'--oem 1 --psm 6', lang='eng+equ')
            ocr_text = ocr_text.strip()
            print(f"Tesseract extracted text (first 100 chars): {ocr_text[:100]}...")

            # Prefer Tesseract only when it clearly beats what Gemini produced.
            if len(ocr_text) > max(len(gemini_text), 20):
                print("Using Tesseract result as fallback.")
                return ocr_text
            if gemini_text:
                print("Keeping Gemini result despite fallback attempt.")
                return gemini_text
            return f"Error: Both Gemini and Tesseract failed to extract sufficient text. Gemini Error: {str(gemini_err)}"
        except Exception as tesseract_err:
            print(f"Tesseract Fallback OCR Error: {tesseract_err}")
            print(traceback.format_exc())
            if gemini_text:
                return gemini_text
            return f"Error: Gemini and Tesseract failed. Gemini: {str(gemini_err)}, Tesseract: {str(tesseract_err)}"
def classify_with_gemini_flash(math_problem):
    """Label a math problem with category, subtopic, difficulty and concepts.

    Args:
        math_problem: The problem text to classify.

    Returns:
        A dict with the keys "category", "subtopic", "difficulty" and
        "key_concepts" (a list). All fields are "Unknown" when the text is
        empty or shorter than 5 characters, or when the model reply cannot be
        parsed as JSON. If the request itself fails, "key_concepts" holds a
        single "Error: ..." entry describing the failure.
    """
    fallback = {
        "category": "Unknown", "subtopic": "Unknown",
        "difficulty": "Unknown", "key_concepts": ["Unknown"]
    }
    if not math_problem or len(math_problem) < 5:
        print("Skipping classification due to insufficient text.")
        return fallback

    try:
        print(f"Classifying problem with {CLASSIFICATION_MODEL}...")
        generation_settings = {
            "temperature": 0.1, "top_p": 0.95,
            "max_output_tokens": 200, "response_mime_type": "application/json",
        }
        classifier = genai.GenerativeModel(
            model_name=CLASSIFICATION_MODEL,
            generation_config=generation_settings,
        )
        prompt = f"""
Task: Classify the following math problem precisely.
PROBLEM:
```
{math_problem}
```
Instructions:
1. Identify the Primary Math Category (e.g., Algebra, Calculus, Geometry, Trigonometry, Statistics, Number Theory, Linear Algebra, Differential Equations).
2. Determine the Specific Subtopic (e.g., Solving Linear Equations, Limits, Euclidean Geometry, Sine Rule, Normal Distribution, Prime Numbers).
3. Assess the Difficulty Level (e.g., High School - Basic, High School - Advanced, College - Introductory, College - Advanced).
4. List the Key Mathematical Concepts involved (be specific, e.g., quadratic formula, integration by parts, Pythagorean theorem, standard deviation).
Format the response STRICTLY as a JSON object with keys: "category", "subtopic", "difficulty", "key_concepts" (where key_concepts is a list of strings).
Example: {{ "category": "Algebra", "subtopic": "Quadratic Equations", "difficulty": "High School - Advanced", "key_concepts": ["quadratic formula", "discriminant", "factoring"] }}
"""
        response = classifier.generate_content(prompt, request_options={'timeout': 60})
        try:
            # The model sometimes wraps its JSON in markdown code fences.
            raw = response.text.strip()
            for fence in ("```json", "```"):
                raw = raw.replace(fence, "")
            cleaned_text = raw.strip()
            parsed = json.loads(cleaned_text)

            # Back-fill any field the model left out.
            if any(key not in parsed for key in fallback):
                print(f"Warning: Classification missing keys. Response: {cleaned_text}")
                for key, default_value in fallback.items():
                    parsed.setdefault(key, default_value)

            # Downstream code expects a list of concepts.
            concepts = parsed.get("key_concepts")
            if not isinstance(concepts, list):
                parsed["key_concepts"] = [str(concepts)]

            print(f"Classification successful: {parsed}")
            return parsed
        except (json.JSONDecodeError, AttributeError) as json_e:
            print(f"JSON Decode/Attribute Error: Unable to parse classification response: {response.text}. Error: {json_e}")
            return fallback
    except Exception as e:
        print(f"Classification Error: {e}")
        print(traceback.format_exc())
        failed = fallback.copy()
        failed["key_concepts"] = [f"Error: {str(e)}"]
        return failed
def solve_with_gemini_pro(math_problem, classification):
    """Produce a step-by-step Markdown solution for a math problem.

    Args:
        math_problem: The problem text; must be at least 5 characters long.
        classification: Dict of classification fields used as context in the
            prompt. Any non-dict value is replaced by an all-"Unknown" dict.

    Returns:
        The model's Markdown solution, or a message describing why no solution
        could be produced (invalid input, very short model output, or an
        exception during generation).
    """
    if not math_problem or len(math_problem) < 5:
        return "Cannot solve: Invalid math problem text provided."

    try:
        print(f"Solving problem with {SOLUTION_MODEL}...")
        solver = genai.GenerativeModel(
            model_name=SOLUTION_MODEL,
            generation_config={
                "temperature": 0.2, "top_p": 0.9,
                "max_output_tokens": 2000,  # Room for complex solutions
            }
        )

        # Normalise the classification so the prompt can always be filled in.
        if not isinstance(classification, dict):
            classification = {"category": "Unknown", "subtopic": "Unknown", "difficulty": "Unknown", "key_concepts": ["Unknown"]}
        key_concepts = classification.get("key_concepts", ["Unknown"])
        if not isinstance(key_concepts, list):
            key_concepts_str = str(key_concepts)
        elif key_concepts:
            key_concepts_str = ", ".join(key_concepts)
        else:
            key_concepts_str = "Unknown"
        category = classification.get("category", "Unknown")
        subtopic = classification.get("subtopic", "Unknown")
        difficulty = classification.get("difficulty", "Unknown")

        prompt = f"""
Task: Solve the following mathematical problem step-by-step. Assume you are a helpful math tutor.
PROBLEM:
```
{math_problem}
```
PROBLEM CONTEXT (from classification):
- Category: {category}
- Subtopic: {subtopic}
- Difficulty: {difficulty}
- Key Concepts: {key_concepts_str}
Instructions:
1. **Understand the Goal:** Briefly state what the problem is asking for.
2. **Identify Strategy/Concepts:** Mention the main mathematical concepts or methods needed (referencing the classification if helpful).
3. **Step-by-Step Solution:** Provide a clear, numbered sequence of steps to reach the solution.
* Explain the reasoning behind each step.
* Show all necessary calculations clearly. Use LaTeX for mathematical notation where appropriate (e.g., $\\frac{{a}}{{b}}$, $x^2$, $\\int f(x) dx$). Wrap inline math in single $ and display math in double $$.
* Define any variables used.
4. **Final Answer:** Clearly state the final answer(s).
5. **Verification (Optional but Recommended):** If possible, briefly describe how the answer could be checked or verified.
6. **Conclusion/Key Takeaway:** Briefly summarize the core concept demonstrated or a key takeaway.
Format the output using Markdown for readability. Use headings, bullet points, and numbered lists effectively. Ensure LaTeX math expressions are correctly formatted.
"""
        response = solver.generate_content(prompt, request_options={'timeout': 180})
        print("Solution generation complete.")

        answer = response.text
        if answer and len(answer) >= 20:
            return answer

        # Very short output: either an explicit refusal from the model, which
        # is passed through unchanged, or a failed generation.
        print(f"Warning: Solution generation produced very short output: {answer}")
        lowered = answer.lower()
        if "cannot solve" in lowered or "don't understand" in lowered:
            return answer
        return f"Error: Solution generation failed or produced incomplete results.\n\nRaw Response:\n{answer}"
    except Exception as e:
        print(f"Solution Error: {e}")
        print(traceback.format_exc())
        return f"## Error Generating Solution\n\nAn error occurred while trying to solve the problem: `{str(e)}`\n\nPlease check the extracted text and try again. If the problem persists, the model might be unable to process this specific query."
def explain_solution(math_problem, solution):
    """Expand an existing solution into a longer, teaching-oriented write-up.

    Args:
        math_problem: The original problem text.
        solution: The solution to elaborate on. Empty values and texts that
            contain "error generating solution" or "cannot solve"
            (case-insensitive) are rejected.

    Returns:
        The model's Markdown explanation, a "Cannot explain" message when no
        usable solution was supplied, or a Markdown error block if generation
        raised an exception.
    """
    if not solution:
        return "Cannot explain: No valid solution provided."
    solution_lower = solution.lower()
    if "error generating solution" in solution_lower or "cannot solve" in solution_lower:
        return "Cannot explain: No valid solution provided."

    try:
        print(f"Generating detailed explanation with {EXPLANATION_MODEL}...")
        explainer = genai.GenerativeModel(
            model_name=EXPLANATION_MODEL,
            generation_config={
                "temperature": 0.3, "top_p": 0.95,
                "max_output_tokens": 2500,  # Detailed explanations need more room
            }
        )
        prompt = f"""
Task: Provide a detailed, pedagogical explanation of the provided solution to a math problem. Assume the reader found the original solution steps difficult to follow.
ORIGINAL PROBLEM:
```
{math_problem}
```
PROVIDED SOLUTION:
```
{solution}
```
Instructions:
Elaborate on the provided solution with the goal of enhancing understanding. Focus on the 'why' behind each step.
1. **Reiterate Goal:** Briefly restate the problem's objective.
2. **Core Concepts Deep Dive:** Explain the fundamental mathematical principles mentioned or implied in the solution in more detail. Use analogies or simpler examples if helpful. Define key terms.
3. **Step-by-Step Elaboration:** Go through the solution steps again, but expand on the reasoning.
* Why was this specific step taken? What rule or theorem justifies it?
* Are there intermediate calculations or assumptions that were skipped? Spell them out.
* Address potential points of confusion.
4. **Connections:** How does this problem relate to broader mathematical ideas or prerequisite knowledge?
5. **Common Pitfalls:** Mention common mistakes students make when tackling similar problems.
6. **Alternative Perspectives (Optional):** Briefly mention if there are other valid ways to approach the problem.
Format the output using Markdown for clarity (headings, lists, bold text). Use LaTeX for math notation (inline $, display $$). Make it easy to read and digest.
"""
        reply = explainer.generate_content(prompt, request_options={'timeout': 180})
        print("Detailed explanation generation complete.")
        return reply.text
    except Exception as e:
        print(f"Explanation Error: {e}")
        print(traceback.format_exc())
        return f"## Error Generating Explanation\n\nAn error occurred: `{str(e)}`"
def generate_similar_problems(math_problem, classification):
    """Generate three practice problems related to the original problem.

    Args:
        math_problem: The original problem text; must be at least 5
            characters long.
        classification: Dict of classification fields used as context in the
            prompt. Any non-dict value is replaced by an all-"Unknown" dict.
            "key_concepts" may be a list/tuple (items are converted to
            strings) or any other value (converted with ``str``).

    Returns:
        The model's Markdown text with the practice problems, a "Cannot
        generate" message for invalid input, or a Markdown error block if
        generation raised an exception.
    """
    if not math_problem or len(math_problem) < 5:
        return "Cannot generate similar problems: Invalid original problem text."
    try:
        print(f"Generating similar problems with {SIMILAR_MODEL}...")
        model = genai.GenerativeModel(
            model_name=SIMILAR_MODEL,
            generation_config={
                "temperature": 0.7, "top_p": 0.95,  # Higher temp for variety
                "max_output_tokens": 1500,
            }
        )
        # Ensure classification is a dict so the prompt can always be filled.
        if not isinstance(classification, dict):
            classification = {"category": "Unknown", "subtopic": "Unknown", "difficulty": "Unknown", "key_concepts": ["Unknown"]}

        # Format the concepts defensively, mirroring solve_with_gemini_pro:
        # joining a plain string would split it into characters, and joining
        # non-string items would raise TypeError.
        key_concepts = classification.get("key_concepts") or ["Unknown"]
        if isinstance(key_concepts, (list, tuple)):
            key_concepts_str = ", ".join(str(concept) for concept in key_concepts)
        else:
            key_concepts_str = str(key_concepts)

        classification_str = f"""
- Category: {classification.get("category", "Unknown")}
- Subtopic: {classification.get("subtopic", "Unknown")}
- Difficulty: {classification.get("difficulty", "Unknown")}
- Key Concepts: {key_concepts_str}
"""
        prompt = f"""
Task: Generate 3 distinct practice math problems that are similar in concept to the original problem provided, but vary slightly in presentation or difficulty.
ORIGINAL PROBLEM:
```
{math_problem}
```
CLASSIFICATION OF ORIGINAL PROBLEM:
{classification_str}
Instructions:
Create three new problems based on the original's concepts and difficulty level.
1. **Problem 1 (Similar Difficulty):** Create a problem that closely mirrors the original in terms of concepts and required steps, but uses different numbers, variables, or context.
2. **Problem 2 (Slightly Easier/Different Focus):** Create a problem that uses the same core concepts but might be slightly simpler, focus on a specific sub-step, or change the type of answer required (e.g., find an intermediate value instead of the final result).
3. **Problem 3 (Slightly Harder/Extension):** Create a problem that builds upon the original concepts, perhaps adding an extra step, combining it with another related concept, or requiring more complex manipulation.
For EACH of the 3 problems:
* Clearly state the problem question. Use LaTeX for math notation.
* Provide a one-sentence HINT on how to approach it.
* Provide the final ANSWER (just the answer, not the steps).
Format the output using Markdown. Use clear headings for each problem (e.g., "### Practice Problem 1 (Similar Difficulty)").
"""
        response = model.generate_content(prompt, request_options={'timeout': 180})
        print("Similar problems generation complete.")
        return response.text
    except Exception as e:
        print(f"Similar Problems Error: {e}")
        print(traceback.format_exc())
        return f"## Error Generating Similar Problems\n\nAn error occurred: `{str(e)}`"
# --- Main Processing Function ---
def process_image(image, progress=gr.Progress()):
    """Run the full pipeline for one uploaded image: OCR, classify, solve.

    Args:
        image: The uploaded image (PIL image or NumPy array), or None when
            nothing was uploaded.
        progress: Gradio progress tracker. It has to be declared as a default
            argument of type ``gr.Progress`` for Gradio to inject a live
            tracker when the function runs as an event handler.

    Returns:
        A 6-tuple matching the outputs wired to the process button:
        (image to display, extracted text, classification JSON string,
        solution Markdown, extracted-text state, classification state).
        On failure the text fields carry the error message and the
        extracted-text state is left empty.
    """
    start_time = time.time()
    try:
        if image is None:
            return None, "Please upload an image first.", "{}", "No image provided.", "", "No image provided."

        progress(0, desc="🚀 Starting...")
        time.sleep(0.5)  # Give UI time to update

        # Step 1: Extract text
        progress(0.1, desc="🔍 Extracting text from image...")
        extracted_text = extract_text_with_gemini(image)
        extraction_errored = bool(extracted_text) and extracted_text.startswith("Error:")
        if not extracted_text or extraction_errored or len(extracted_text) < 10:
            err_msg = extracted_text if extraction_errored else "Error: Could not extract sufficient text from the image. Please try a clearer image or check Tesseract configuration if using fallback."
            print(f"Text extraction failed or insufficient: {err_msg}")
            # Show the uploaded image back to the user
            img_display = image if isinstance(image, Image.Image) else Image.fromarray(image)
            return img_display, err_msg, "{}", err_msg, "", err_msg

        # Step 2: Classify
        progress(0.4, desc=f"📊 Classifying problem ({CLASSIFICATION_MODEL})...")
        classification = classify_with_gemini_flash(extracted_text)
        classification_json = json.dumps(classification, indent=2)

        # Step 3: Solve
        progress(0.6, desc=f"💡 Solving problem ({SOLUTION_MODEL})...")
        solution = solve_with_gemini_pro(extracted_text, classification)

        end_time = time.time()
        progress(1.0, desc=f"✅ Done in {end_time - start_time:.2f}s!")

        # The last two values refresh the state used by the follow-up buttons.
        img_display = image if isinstance(image, Image.Image) else Image.fromarray(image)
        return img_display, extracted_text, classification_json, solution, extracted_text, classification_json
    except Exception as e:
        print(f"Process Image Error: {e}")
        print(traceback.format_exc())
        error_message = f"An unexpected error occurred: {str(e)}"
        # Best effort: echo the upload back, but never let the conversion
        # itself raise out of the error handler.
        img_display = None
        if image is not None:
            try:
                img_display = image if isinstance(image, Image.Image) else Image.fromarray(image)
            except Exception as display_err:
                print(f"Could not prepare image for display: {display_err}")
        return img_display, error_message, "{}", error_message, "", error_message
# --- Gradio Interface ---
# Custom CSS passed to gr.Blocks(css=...). The #id selectors correspond to the
# elem_id values assigned to components in the layout below.
css = """
body { font-family: 'Inter', sans-serif; } /* Modern font */
.gradio-container { background-color: #f8f9fa; } /* Light background */
#title_markdown h1 {
text-align: center;
color: #4A90E2; /* Theme color */
font-weight: 600;
margin-bottom: 0px; /* Adjust spacing */
}
#subtitle_markdown p {
text-align: center;
color: #555;
margin-top: 5px; /* Adjust spacing */
margin-bottom: 20px;
}
/* Input/Output Image Area */
#input_col, #output_col { padding: 10px; }
#input_image, #processed_image {
border-radius: 8px; /* Rounded corners for images */
border: 1px solid #dee2e6;
overflow: hidden; /* Ensure border radius applies */
height: 350px; /* Fixed height */
object-fit: contain; /* Scale image nicely */
}
#input_image div[data-testid="image"], #processed_image div[data-testid="image"] {
height: 100%; /* Make inner div fill height */
}
#input_image img, #processed_image img {
height: 100%; object-fit: contain; /* Control image scaling */
}
/* Main button */
#process_button { margin-top: 15px; }
/* Output sections */
#results_group {
border: 1px solid #e9ecef;
border-radius: 8px;
padding: 15px;
background-color: #ffffff; /* White background for results */
box-shadow: 0 2px 4px rgba(0,0,0,0.05); /* Subtle shadow */
margin-top: 20px;
}
#extracted_text_output textarea, #classification_output textarea {
background-color: #f1f3f4 !important; /* Light grey background for text boxes */
border-radius: 4px;
}
#solution_output { margin-top: 15px; }
/* Action buttons below solution */
#action_buttons { margin-top: 15px; margin-bottom: 15px; }
/* Accordion styling */
.gradio-accordion > button { /* Target the accordion header button */
background-color: #eef2f6; /* Lighter header background */
border-radius: 5px 5px 0 0; /* Rounded top corners */
font-weight: 500;
}
.gradio-accordion .gradio-markdown { /* Content inside accordion */
border: 1px solid #dee2e6;
border-top: none; /* Remove top border as header has it */
padding: 15px;
border-radius: 0 0 5px 5px; /* Rounded bottom corners */
background-color: #fff; /* White background */
}
footer { visibility: hidden } /* Hide default Gradio footer */
"""
# Theme: the stock Default theme with blue/sky hues, then explicit overrides
# for button colours and block corner radius.
# (A simpler alternative would be gr_themes.Soft(primary_hue="blue", secondary_hue="sky").)
theme = gr_themes.Default(
    primary_hue=gr_themes.colors.blue,
    secondary_hue=gr_themes.colors.sky,
).set(
    button_primary_background_fill="#4A90E2",
    button_primary_background_fill_hover="#357ABD",
    button_secondary_background_fill="#E1E8ED",
    button_secondary_background_fill_hover="#CED9E0",
    block_radius="8px",  # Consistent border radius
)
# Build the UI. Everything inside this context is registered on `demo`:
# layout, the state shared between buttons, and the event wiring.
with gr.Blocks(theme=theme, css=css, title="NerdAI Math Solver") as demo:
    # --- State Variables ---
    # Extracted text and classification are kept so the follow-up buttons can
    # reuse them without re-running OCR/classification.
    extracted_text_state = gr.State("")
    classification_state = gr.State("{}")  # Stored as a JSON string

    # --- UI Layout ---
    gr.Markdown("# 🧠 NerdAI Math Problem Solver", elem_id="title_markdown")
    gr.Markdown("Upload a clear image of a math problem. NerdAI will extract the text, classify it, solve it step-by-step, and offer further help!", elem_id="subtitle_markdown")
    with gr.Row():
        with gr.Column(scale=1, elem_id="input_col"):
            input_image = gr.Image(label="Upload Math Problem", type="pil", elem_id="input_image", height=350)
            process_btn = gr.Button("✨ Process Image and Solve", variant="primary", elem_id="process_button")
        with gr.Column(scale=1, elem_id="output_col"):
            processed_image = gr.Image(label="Processed Image", interactive=False, elem_id="processed_image", height=350)

    # --- Results Area ---
    with gr.Group(elem_id="results_group"):
        gr.Markdown("### Results")
        # NOTE(review): gr.Box is not available in every Gradio major version;
        # confirm against the pinned gradio release (gr.Group is the usual
        # replacement).
        with gr.Box():  # Box for slight visual separation
            extracted_text_output = gr.Textbox(label="📝 Extracted Text", lines=3, interactive=False, placeholder="Text from the image will appear here...", elem_id="extracted_text_output")
        with gr.Box():
            classification_output = gr.Textbox(label=f"📊 Problem Classification ({CLASSIFICATION_MODEL})", lines=5, interactive=False, placeholder="Problem type analysis will appear here...", elem_id="classification_output")
        solution_output = gr.Markdown(label="✅ Solution Steps", value="*Solution steps will appear here after processing...*", elem_id="solution_output")

    # --- Action Buttons ---
    with gr.Row(elem_id="action_buttons"):
        explain_btn = gr.Button("🤔 Explain Further", variant="secondary")
        similar_btn = gr.Button("📚 Similar Questions", variant="secondary")

    # --- Accordion for Detailed Outputs ---
    with gr.Accordion("Detailed Explanation", open=False):
        explanation_output = gr.Markdown(value="*Click 'Explain Further' above to get a detailed breakdown.*")
    with gr.Accordion("Similar Practice Problems", open=False):
        similar_problems_output = gr.Markdown(value="*Click 'Similar Questions' above to generate practice problems.*")

    # --- Event Handlers ---
    # Main process button click
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[
            processed_image,
            extracted_text_output,
            classification_output,
            solution_output,
            extracted_text_state,  # Update state
            classification_state,  # Update state
        ],
        # api_name="process_math_image"  # Optional: for API usage
    )

    # Explain button click
    def explain_button_handler(current_problem_text, current_solution_md):
        """Stream a detailed explanation for the current solution.

        This is a generator: every message meant for the UI must be yielded.
        A value passed to ``return`` inside a generator is not delivered as an
        output, so validation failures yield their message and then return.
        """
        print("Explain button clicked.")
        if not current_problem_text or current_problem_text.startswith("Error:") or current_problem_text == "No image provided." or current_problem_text == "Please upload an image first.":
            yield "Please successfully process an image first to get text and a solution."
            return
        if not current_solution_md or current_solution_md.startswith("Error") or "will appear here" in current_solution_md:
            yield "Cannot explain: A valid solution needs to be generated first."
            return
        # Interim message shown while the model is working.
        yield "*Generating detailed explanation... please wait.*"
        yield explain_solution(current_problem_text, current_solution_md)

    explain_btn.click(
        fn=explain_button_handler,
        inputs=[extracted_text_state, solution_output],  # State plus current solution output
        outputs=explanation_output,  # The Markdown inside the Accordion
    )

    # Similar problems button click
    def similar_button_handler(current_problem_text, current_classification_json):
        """Stream practice problems based on the stored text and classification.

        Generator handler: user-facing messages are yielded, never returned.
        """
        print("Similar button clicked.")
        if not current_problem_text or current_problem_text.startswith("Error:") or current_problem_text == "No image provided." or current_problem_text == "Please upload an image first.":
            yield "Please successfully process an image first to get the problem text and classification."
            return
        # Interim message shown while the model is working.
        yield "*Generating similar problems... please wait.*"
        try:
            classification_dict = json.loads(current_classification_json)
            # Minimal validation
            if not isinstance(classification_dict, dict) or not classification_dict:
                raise ValueError("Invalid classification data.")
        except (json.JSONDecodeError, ValueError) as e:
            print(f"Error parsing classification state for similar problems: {e}")
            yield f"Error: Could not use problem classification data ({e}). Please ensure the problem was classified correctly."
            return
        yield generate_similar_problems(current_problem_text, classification_dict)

    similar_btn.click(
        fn=similar_button_handler,
        inputs=[extracted_text_state, classification_state],  # Use state
        outputs=similar_problems_output,  # The Markdown inside the Accordion
    )

    # Example images (optional). Only files that exist right now are offered:
    # this block runs at import time, before any placeholder files could be
    # created by the launcher section, and a missing example file should not
    # prevent the app from starting.
    example_image_paths = [
        "examples/algebra_problem.png",
        "examples/calculus_problem.jpg",
        "examples/geometry_problem.png",
    ]
    available_examples = [[path] for path in example_image_paths if os.path.exists(path)]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=input_image,
            # outputs=[processed_image, extracted_text_output, classification_output, solution_output, extracted_text_state, classification_state],  # Outputs for examples if you want to auto-run them
            # fn=process_image,  # Function to run when example is clicked
            cache_examples=False,  # Better to re-run for dynamic models
            label="Example Math Problems",
        )
# --- Launch the App ---
if __name__ == "__main__":
    # For local testing, make sure the example directory and its image files
    # exist, creating simple placeholder images for any that are missing.
    # Note that this runs after the Blocks UI above has already been built.
    examples_dir = "examples"
    if not os.path.exists(examples_dir):
        os.makedirs(examples_dir)

    for fname in ["algebra_problem.png", "calculus_problem.jpg", "geometry_problem.png"]:
        fpath = os.path.join(examples_dir, fname)
        if os.path.exists(fpath):
            continue
        try:
            from PIL import ImageDraw

            # Flat-colour canvas with the file name written on it.
            placeholder = Image.new('RGB', (200, 100), color=(73, 109, 137))
            ImageDraw.Draw(placeholder).text((10, 10), f"Placeholder for\n{fname}", fill=(255, 255, 0))
            placeholder.save(fpath)
            print(f"Created placeholder example: {fpath}")
        except Exception as e:
            print(f"Could not create placeholder image {fpath}: {e}")

    # The queue is recommended for multiple users and long-running tasks.
    # debug=True gives more logs; remove for production.
    demo.queue().launch(debug=True)