Spaces:
Sleeping
Sleeping
import os | |
import json | |
import gradio as gr | |
import gradio.themes as gr_themes # Import themes | |
import google.generativeai as genai | |
from PIL import Image | |
import numpy as np | |
from huggingface_hub import HfFolder | |
from dotenv import load_dotenv | |
import traceback | |
import pytesseract | |
import cv2 | |
import time | |
# --- Load Environment Variables ---
# A local .env file is loaded first, then the key is read from the environment.
# Hugging Face Spaces exposes repository Secrets as environment variables, so this
# single lookup covers both local runs and Spaces deployments.
# (The former HfFolder.get_token("GEMINI_API_KEY") fallback was removed: get_token()
# accepts no arguments and returns the Hugging Face token, not a Gemini key; and
# `import secrets` imported the stdlib module, which never has GEMINI_API_KEY.)
load_dotenv()
GEMINI_API_KEY = (os.getenv("GEMINI_API_KEY") or "").strip()  # strip stray newlines from pasted secrets
if not GEMINI_API_KEY:
    raise ValueError(
        "Gemini API key not found. Please set the GEMINI_API_KEY environment variable "
        "or add it as a Secret if running on Hugging Face Spaces."
    )
genai.configure(api_key=GEMINI_API_KEY)
# --- Gemini model identifiers ---
# Flash model for problem classification; Pro model for solving, explanations,
# practice-problem generation and OCR of the uploaded image.
CLASSIFICATION_MODEL = "gemini-1.5-flash"
SOLUTION_MODEL = EXPLANATION_MODEL = SIMILAR_MODEL = "gemini-1.5-pro-latest"
MODEL_IMAGE = "gemini-1.5-pro-latest"  # Pro is also used for image OCR
print(
    "Using models: "
    f"Classification: {CLASSIFICATION_MODEL}, "
    f"Solution: {SOLUTION_MODEL}, "
    f"Explanation: {EXPLANATION_MODEL}, "
    f"Similar: {SIMILAR_MODEL}, "
    f"Image Analysis: {MODEL_IMAGE}"
)
# --- Locate the Tesseract binary used for the fallback OCR ---
# Well-known install locations are probed in order (system path, then macOS
# Homebrew); if neither exists, PATH is searched. Nothing here raises: Gemini is
# the primary OCR engine, so a missing Tesseract only produces a warning.
try:
    for _tesseract_candidate in ('/usr/bin/tesseract', '/opt/homebrew/bin/tesseract'):
        if os.path.exists(_tesseract_candidate):
            pytesseract.pytesseract.tesseract_cmd = _tesseract_candidate
            break
    else:
        # Neither well-known location exists; fall back to whatever is on PATH.
        from shutil import which
        tesseract_path = which('tesseract')
        if not tesseract_path:
            print("Warning: Tesseract command not found at specified paths or in PATH. Fallback OCR might fail.")
        else:
            pytesseract.pytesseract.tesseract_cmd = tesseract_path
except Exception as e:
    print(f"Warning: Error setting Tesseract path: {e}. Fallback OCR might fail.")
# --- Backend Functions (Keep core logic, add minor logging/error handling improvements) ---
def extract_text_with_gemini(image):
    """Extract text from a problem image with Gemini, falling back to Tesseract.

    Gemini is tried first. If it raises, or returns fewer than 15 characters or
    a message containing "unable to extract", Tesseract is run on a grayscale
    copy of the image. The Tesseract text is used only if it is longer than both
    the Gemini text and 20 characters.

    Args:
        image: A PIL image or a NumPy array.

    Returns:
        The extracted text, or a message starting with "Error:" when neither
        engine produced usable text. This function does not raise.
    """
    extracted_text = ""
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        print("Attempting text extraction with Gemini Pro Vision...")
        model = genai.GenerativeModel(MODEL_IMAGE)
        prompt = """Extract ALL text, numbers, and mathematical equations from this image precisely.
Include ALL symbols, numbers, letters, and mathematical notation exactly as they appear.
Format any equations properly and maintain their layout as much as possible.
Do not add any commentary or explanation, just output the extracted text verbatim."""
        response = model.generate_content([prompt, image], request_options={'timeout': 120})
        extracted_text = response.text.strip()
        print(f"Gemini extracted text (first 100 chars): {extracted_text[:100]}...")
        # Very short output or an explicit refusal is treated as a failed extraction.
        if len(extracted_text) < 15 or "unable to extract" in extracted_text.lower():
            print("Gemini returned limited or no text, trying Tesseract as fallback...")
            raise ValueError("Gemini extraction insufficient, attempting fallback.")  # Trigger fallback
        return extracted_text
    except Exception as e:
        print(f"Gemini Extraction Error: {e}. Attempting Tesseract fallback.")
        print(traceback.format_exc())
        try:
            if not getattr(pytesseract.pytesseract, 'tesseract_cmd', None):
                print("Tesseract is not configured. Skipping fallback.")
                return extracted_text if extracted_text else f"Error: Gemini failed and Tesseract is not available. Details: {str(e)}"

            # Tesseract is run on a single-channel (grayscale) array.
            if isinstance(image, Image.Image):
                image_array = np.array(image.convert('L'))
            elif isinstance(image, np.ndarray):
                image_array = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image
            else:
                return f"Error: Unsupported image type for Tesseract fallback. Gemini Error: {str(e)}"

            custom_config = r'--oem 1 --psm 6'  # LSTM engine; treat the image as a single block of text
            try:
                tesseract_text = pytesseract.image_to_string(image_array, config=custom_config, lang='eng+equ')
            except pytesseract.TesseractError as lang_err:
                # The 'equ' (equation) language pack is frequently not installed; without
                # this retry a missing pack would make the entire fallback fail.
                print(f"Tesseract with 'eng+equ' failed ({lang_err}); retrying with 'eng' only.")
                tesseract_text = pytesseract.image_to_string(image_array, config=custom_config, lang='eng')
            tesseract_text = tesseract_text.strip()
            print(f"Tesseract extracted text (first 100 chars): {tesseract_text[:100]}...")

            # Prefer Tesseract only when it is clearly better than a poor Gemini result.
            if len(tesseract_text) > max(len(extracted_text), 20):
                print("Using Tesseract result as fallback.")
                return tesseract_text
            if extracted_text:
                print("Keeping Gemini result despite fallback attempt.")
                return extracted_text
            return f"Error: Both Gemini and Tesseract failed to extract sufficient text. Gemini Error: {str(e)}"
        except Exception as e2:
            print(f"Tesseract Fallback OCR Error: {e2}")
            print(traceback.format_exc())
            return extracted_text if extracted_text else f"Error: Gemini and Tesseract failed. Gemini: {str(e)}, Tesseract: {str(e2)}"
def classify_with_gemini_flash(math_problem):
    """Classify a math problem with the Gemini classification model.

    Args:
        math_problem: The problem text (typically the OCR output).

    Returns:
        A dict with keys "category", "subtopic", "difficulty" and
        "key_concepts" (always a list of strings). Short input or an
        unparseable response yields "Unknown" placeholders; an API failure
        yields the placeholders with the error text as the only key concept.
    """
    default_classification = {
        "category": "Unknown", "subtopic": "Unknown",
        "difficulty": "Unknown", "key_concepts": ["Unknown"],
    }
    if not math_problem or len(math_problem) < 5:  # Too short to classify meaningfully
        print("Skipping classification due to insufficient text.")
        return default_classification
    try:
        print(f"Classifying problem with {CLASSIFICATION_MODEL}...")
        model = genai.GenerativeModel(
            model_name=CLASSIFICATION_MODEL,
            generation_config={
                "temperature": 0.1, "top_p": 0.95,
                "max_output_tokens": 200, "response_mime_type": "application/json",
            },
        )
        prompt = f"""
Task: Classify the following math problem precisely.
PROBLEM:
```
{math_problem}
```
Instructions:
1. Identify the Primary Math Category (e.g., Algebra, Calculus, Geometry, Trigonometry, Statistics, Number Theory, Linear Algebra, Differential Equations).
2. Determine the Specific Subtopic (e.g., Solving Linear Equations, Limits, Euclidean Geometry, Sine Rule, Normal Distribution, Prime Numbers).
3. Assess the Difficulty Level (e.g., High School - Basic, High School - Advanced, College - Introductory, College - Advanced).
4. List the Key Mathematical Concepts involved (be specific, e.g., quadratic formula, integration by parts, Pythagorean theorem, standard deviation).
Format the response STRICTLY as a JSON object with keys: "category", "subtopic", "difficulty", "key_concepts" (where key_concepts is a list of strings).
Example: {{ "category": "Algebra", "subtopic": "Quadratic Equations", "difficulty": "High School - Advanced", "key_concepts": ["quadratic formula", "discriminant", "factoring"] }}
"""
        response = model.generate_content(prompt, request_options={'timeout': 60})
        raw_text = ""
        try:
            raw_text = response.text  # read once; re-reading in the handler could raise again
            # Clean potential markdown code fences around the JSON payload.
            cleaned_text = raw_text.strip().replace("```json", "").replace("```", "").strip()
            classification = json.loads(cleaned_text)
        except (json.JSONDecodeError, AttributeError) as json_e:
            print(f"JSON Decode/Attribute Error: Unable to parse classification response: {raw_text}. Error: {json_e}")
            return default_classification

        # Valid JSON that is not an object (list, string, number) cannot be used.
        if not isinstance(classification, dict):
            print(f"Warning: Classification is not a JSON object. Response: {cleaned_text}")
            return default_classification
        if not all(k in classification for k in default_classification):
            print(f"Warning: Classification missing keys. Response: {cleaned_text}")
            for k, v in default_classification.items():
                classification.setdefault(k, v)

        # Callers join key_concepts with ", ", so normalise it to a non-empty list of str.
        concepts = classification.get("key_concepts")
        if isinstance(concepts, list):
            classification["key_concepts"] = [str(c) for c in concepts] or ["Unknown"]
        else:
            classification["key_concepts"] = [str(concepts) if concepts is not None else "Unknown"]
        print(f"Classification successful: {classification}")
        return classification
    except Exception as e:
        print(f"Classification Error: {e}")
        print(traceback.format_exc())
        error_classification = default_classification.copy()
        error_classification["key_concepts"] = [f"Error: {str(e)}"]
        return error_classification
def solve_with_gemini_pro(math_problem, classification):
    """Solve a math problem step by step with the Gemini solution model.

    Args:
        math_problem: The problem text.
        classification: Dict from ``classify_with_gemini_flash`` (anything else
            is replaced by "Unknown" placeholders).

    Returns:
        Markdown solution text. Failures are returned as strings, never raised:
        "Cannot solve: ..." for unusable input, "Error: Solution generation
        failed ..." for suspiciously short output, or an
        "## Error Generating Solution" section if the API call fails.
    """
    if not math_problem or len(math_problem) < 5:
        return "Cannot solve: Invalid math problem text provided."
    try:
        print(f"Solving problem with {SOLUTION_MODEL}...")
        model = genai.GenerativeModel(
            model_name=SOLUTION_MODEL,
            generation_config={
                "temperature": 0.2, "top_p": 0.9,
                "max_output_tokens": 2000,  # Increased token limit for complex solutions
            },
        )
        if not isinstance(classification, dict):
            classification = {"category": "Unknown", "subtopic": "Unknown", "difficulty": "Unknown", "key_concepts": ["Unknown"]}
        key_concepts = classification.get("key_concepts", ["Unknown"])
        if isinstance(key_concepts, list):
            # str() each item: the model may return non-string concepts, which join() rejects.
            key_concepts_str = ", ".join(str(c) for c in key_concepts) or "Unknown"
        else:
            key_concepts_str = str(key_concepts)
        prompt = f"""
Task: Solve the following mathematical problem step-by-step. Assume you are a helpful math tutor.
PROBLEM:
```
{math_problem}
```
PROBLEM CONTEXT (from classification):
- Category: {classification.get("category", "Unknown")}
- Subtopic: {classification.get("subtopic", "Unknown")}
- Difficulty: {classification.get("difficulty", "Unknown")}
- Key Concepts: {key_concepts_str}
Instructions:
1. **Understand the Goal:** Briefly state what the problem is asking for.
2. **Identify Strategy/Concepts:** Mention the main mathematical concepts or methods needed (referencing the classification if helpful).
3. **Step-by-Step Solution:** Provide a clear, numbered sequence of steps to reach the solution.
* Explain the reasoning behind each step.
* Show all necessary calculations clearly. Use LaTeX for mathematical notation where appropriate (e.g., $\\frac{{a}}{{b}}$, $x^2$, $\\int f(x) dx$). Wrap inline math in single $ and display math in double $$.
* Define any variables used.
4. **Final Answer:** Clearly state the final answer(s).
5. **Verification (Optional but Recommended):** If possible, briefly describe how the answer could be checked or verified.
6. **Conclusion/Key Takeaway:** Briefly summarize the core concept demonstrated or a key takeaway.
Format the output using Markdown for readability. Use headings, bullet points, and numbered lists effectively. Ensure LaTeX math expressions are correctly formatted.
"""
        response = model.generate_content(prompt, request_options={'timeout': 180})  # Increased timeout for complex solves
        print("Solution generation complete.")
        solution_text = response.text or ""  # read once
        # Very short output is almost certainly a failed or refused generation.
        if len(solution_text) < 20:
            print(f"Warning: Solution generation produced very short output: {solution_text}")
            lowered = solution_text.lower()
            if "cannot solve" in lowered or "don't understand" in lowered:
                return solution_text  # Return Gemini's explicit failure message
            return f"Error: Solution generation failed or produced incomplete results.\n\nRaw Response:\n{solution_text}"
        return solution_text
    except Exception as e:
        print(f"Solution Error: {e}")
        print(traceback.format_exc())
        return f"## Error Generating Solution\n\nAn error occurred while trying to solve the problem: `{str(e)}`\n\nPlease check the extracted text and try again. If the problem persists, the model might be unable to process this specific query."
def _is_failed_solution(solution):
    """Return True if *solution* is empty or one of the solver's failure outputs."""
    if not solution or not isinstance(solution, str):
        return True
    text = solution.strip()
    lowered = text.lower()
    # The solver's own failure messages start with "Cannot solve", "Error:" or
    # "## Error Generating Solution". A bare substring test for "cannot solve"
    # would also reject valid solutions such as "We cannot solve this by factoring...".
    if lowered.startswith(("cannot solve", "error:", "## error")) or "error generating solution" in lowered:
        return True
    # The solver passes through outputs shorter than 20 characters only when they
    # are the model's own refusal.
    return len(text) < 20 and ("cannot solve" in lowered or "don't understand" in lowered)


def explain_solution(math_problem, solution):
    """Generate a deeper, pedagogical explanation of an existing solution.

    Args:
        math_problem: The original problem text.
        solution: The Markdown solution text to elaborate on.

    Returns:
        Markdown explanation; "Cannot explain: No valid solution provided." when
        *solution* is missing or is a solver failure message; or an
        "## Error Generating Explanation" section if the API call fails.
    """
    if _is_failed_solution(solution):
        return "Cannot explain: No valid solution provided."
    try:
        print(f"Generating detailed explanation with {EXPLANATION_MODEL}...")
        model = genai.GenerativeModel(
            model_name=EXPLANATION_MODEL,
            generation_config={
                "temperature": 0.3, "top_p": 0.95,
                "max_output_tokens": 2500,  # Allow more tokens for detailed explanation
            },
        )
        prompt = f"""
Task: Provide a detailed, pedagogical explanation of the provided solution to a math problem. Assume the reader found the original solution steps difficult to follow.
ORIGINAL PROBLEM:
```
{math_problem}
```
PROVIDED SOLUTION:
```
{solution}
```
Instructions:
Elaborate on the provided solution with the goal of enhancing understanding. Focus on the 'why' behind each step.
1. **Reiterate Goal:** Briefly restate the problem's objective.
2. **Core Concepts Deep Dive:** Explain the fundamental mathematical principles mentioned or implied in the solution in more detail. Use analogies or simpler examples if helpful. Define key terms.
3. **Step-by-Step Elaboration:** Go through the solution steps again, but expand on the reasoning.
* Why was this specific step taken? What rule or theorem justifies it?
* Are there intermediate calculations or assumptions that were skipped? Spell them out.
* Address potential points of confusion.
4. **Connections:** How does this problem relate to broader mathematical ideas or prerequisite knowledge?
5. **Common Pitfalls:** Mention common mistakes students make when tackling similar problems.
6. **Alternative Perspectives (Optional):** Briefly mention if there are other valid ways to approach the problem.
Format the output using Markdown for clarity (headings, lists, bold text). Use LaTeX for math notation (inline $, display $$). Make it easy to read and digest.
"""
        response = model.generate_content(prompt, request_options={'timeout': 180})
        print("Detailed explanation generation complete.")
        return response.text
    except Exception as e:
        print(f"Explanation Error: {e}")
        print(traceback.format_exc())
        return f"## Error Generating Explanation\n\nAn error occurred: `{str(e)}`"
def generate_similar_problems(math_problem, classification):
    """Generate three practice problems modelled on the original problem.

    Args:
        math_problem: The original problem text.
        classification: Dict from ``classify_with_gemini_flash`` (anything else
            is replaced by "Unknown" placeholders).

    Returns:
        Markdown with three practice problems (each with hint and answer), a
        "Cannot generate ..." message for unusable input, or an
        "## Error Generating Similar Problems" section if the API call fails.
    """
    if not math_problem or len(math_problem) < 5:
        return "Cannot generate similar problems: Invalid original problem text."
    try:
        print(f"Generating similar problems with {SIMILAR_MODEL}...")
        model = genai.GenerativeModel(
            model_name=SIMILAR_MODEL,
            generation_config={
                "temperature": 0.7, "top_p": 0.95,  # Higher temp for variety
                "max_output_tokens": 1500,
            },
        )
        if not isinstance(classification, dict):
            classification = {"category": "Unknown", "subtopic": "Unknown", "difficulty": "Unknown", "key_concepts": ["Unknown"]}
        key_concepts = classification.get("key_concepts", ["Unknown"])
        if isinstance(key_concepts, list):
            key_concepts_str = ", ".join(str(c) for c in key_concepts) or "Unknown"
        else:
            # A bare string must not go through join(), which would split it into characters.
            key_concepts_str = str(key_concepts)
        classification_str = f"""
- Category: {classification.get("category", "Unknown")}
- Subtopic: {classification.get("subtopic", "Unknown")}
- Difficulty: {classification.get("difficulty", "Unknown")}
- Key Concepts: {key_concepts_str}
"""
        prompt = f"""
Task: Generate 3 distinct practice math problems that are similar in concept to the original problem provided, but vary slightly in presentation or difficulty.
ORIGINAL PROBLEM:
```
{math_problem}
```
CLASSIFICATION OF ORIGINAL PROBLEM:
{classification_str}
Instructions:
Create three new problems based on the original's concepts and difficulty level.
1. **Problem 1 (Similar Difficulty):** Create a problem that closely mirrors the original in terms of concepts and required steps, but uses different numbers, variables, or context.
2. **Problem 2 (Slightly Easier/Different Focus):** Create a problem that uses the same core concepts but might be slightly simpler, focus on a specific sub-step, or change the type of answer required (e.g., find an intermediate value instead of the final result).
3. **Problem 3 (Slightly Harder/Extension):** Create a problem that builds upon the original concepts, perhaps adding an extra step, combining it with another related concept, or requiring more complex manipulation.
For EACH of the 3 problems:
* Clearly state the problem question. Use LaTeX for math notation.
* Provide a one-sentence HINT on how to approach it.
* Provide the final ANSWER (just the answer, not the steps).
Format the output using Markdown. Use clear headings for each problem (e.g., "### Practice Problem 1 (Similar Difficulty)").
"""
        response = model.generate_content(prompt, request_options={'timeout': 180})
        print("Similar problems generation complete.")
        return response.text
    except Exception as e:
        print(f"Similar Problems Error: {e}")
        print(traceback.format_exc())
        return f"## Error Generating Similar Problems\n\nAn error occurred: `{str(e)}`"
# --- Main Processing Function ---
def _to_display_image(image):
    """Return *image* as a PIL image for the "Processed Image" output."""
    return image if isinstance(image, Image.Image) else Image.fromarray(image)


def process_image(image, progress=gr.Progress()):
    """Run the NerdAI pipeline: OCR -> classification -> step-by-step solution.

    Args:
        image: The uploaded image (PIL image or NumPy array), or None.
        progress: Progress tracker. Gradio injects a live tracker when the
            default is ``gr.Progress()``; without this parameter ``progress``
            was an undefined name and every run failed with NameError.

    Returns:
        A 6-tuple matching the click handler's outputs: (display image,
        extracted text, classification JSON, solution markdown,
        extracted-text state, classification-JSON state). On failure the
        visible fields carry the error message, the text state is "" and the
        classification state is "{}" so the follow-up buttons refuse to run.
    """
    start_time = time.time()
    try:
        if image is None:
            return None, "Please upload an image first.", "{}", "No image provided.", "", "{}"
        progress(0, desc="🚀 Starting...")

        # Step 1: Extract text (Gemini with Tesseract fallback).
        progress(0.1, desc="🔍 Extracting text from image...")
        extracted_text = extract_text_with_gemini(image)
        if extracted_text and extracted_text.startswith("Error:"):
            err_msg = extracted_text
        elif not extracted_text or len(extracted_text) < 10:
            err_msg = "Error: Could not extract sufficient text from the image. Please try a clearer image or check Tesseract configuration if using fallback."
        else:
            err_msg = None
        if err_msg:
            print(f"Text extraction failed or insufficient: {err_msg}")
            # Still echo the uploaded image back to the user.
            return _to_display_image(image), err_msg, "{}", err_msg, "", "{}"

        # Step 2: Classify.
        progress(0.4, desc=f"📊 Classifying problem ({CLASSIFICATION_MODEL})...")
        classification = classify_with_gemini_flash(extracted_text)
        classification_json = json.dumps(classification, indent=2)

        # Step 3: Solve.
        progress(0.6, desc=f"💡 Solving problem ({SOLUTION_MODEL})...")
        solution = solve_with_gemini_pro(extracted_text, classification)

        elapsed = time.time() - start_time
        progress(1.0, desc=f"✅ Done in {elapsed:.2f}s!")
        # The last two values feed the session state used by the follow-up buttons.
        return _to_display_image(image), extracted_text, classification_json, solution, extracted_text, classification_json
    except Exception as e:
        print(f"Process Image Error: {e}")
        print(traceback.format_exc())
        error_message = f"An unexpected error occurred: {str(e)}"
        img_display = None
        if image is not None:
            try:
                img_display = _to_display_image(image)
            except Exception:
                img_display = None  # never let a bad image mask the original error
        return img_display, error_message, "{}", error_message, "", "{}"
# --- Gradio Interface (Major Changes Here) ---
# Custom CSS for styling.
# This string is passed to gr.Blocks(css=...). The #id selectors target the
# elem_id values given to components in the layout (title_markdown,
# subtitle_markdown, input_image, processed_image, results_group, ...).
# NOTE(review): the .gradio-accordion / .gradio-markdown class selectors and the
# div[data-testid="image"] selector depend on Gradio's internal DOM, which changes
# between releases — confirm they still match the installed Gradio version.
css = """
body { font-family: 'Inter', sans-serif; } /* Modern font */
.gradio-container { background-color: #f8f9fa; } /* Light background */
#title_markdown h1 {
text-align: center;
color: #4A90E2; /* Theme color */
font-weight: 600;
margin-bottom: 0px; /* Adjust spacing */
}
#subtitle_markdown p {
text-align: center;
color: #555;
margin-top: 5px; /* Adjust spacing */
margin-bottom: 20px;
}
/* Input/Output Image Area */
#input_col, #output_col { padding: 10px; }
#input_image, #processed_image {
border-radius: 8px; /* Rounded corners for images */
border: 1px solid #dee2e6;
overflow: hidden; /* Ensure border radius applies */
height: 350px; /* Fixed height */
object-fit: contain; /* Scale image nicely */
}
#input_image div[data-testid="image"], #processed_image div[data-testid="image"] {
height: 100%; /* Make inner div fill height */
}
#input_image img, #processed_image img {
height: 100%; object-fit: contain; /* Control image scaling */
}
/* Main button */
#process_button { margin-top: 15px; }
/* Output sections */
#results_group {
border: 1px solid #e9ecef;
border-radius: 8px;
padding: 15px;
background-color: #ffffff; /* White background for results */
box-shadow: 0 2px 4px rgba(0,0,0,0.05); /* Subtle shadow */
margin-top: 20px;
}
#extracted_text_output textarea, #classification_output textarea {
background-color: #f1f3f4 !important; /* Light grey background for text boxes */
border-radius: 4px;
}
#solution_output { margin-top: 15px; }
/* Action buttons below solution */
#action_buttons { margin-top: 15px; margin-bottom: 15px; }
/* Accordion styling */
.gradio-accordion > button { /* Target the accordion header button */
background-color: #eef2f6; /* Lighter header background */
border-radius: 5px 5px 0 0; /* Rounded top corners */
font-weight: 500;
}
.gradio-accordion .gradio-markdown { /* Content inside accordion */
border: 1px solid #dee2e6;
border-top: none; /* Remove top border as header has it */
padding: 15px;
border-radius: 0 0 5px 5px; /* Rounded bottom corners */
background-color: #fff; /* White background */
}
footer { visibility: hidden } /* Hide default Gradio footer */
"""
# Theme: Gradio's Default theme with blue/sky hues. The button fills and block
# corner radius are overridden so they match the #4A90E2 accent used in the CSS.
theme = gr_themes.Default(
    primary_hue=gr_themes.colors.blue,
    secondary_hue=gr_themes.colors.sky,
).set(
    block_radius="8px",  # Consistent border radius
    button_primary_background_fill="#4A90E2",
    button_primary_background_fill_hover="#357ABD",
    button_secondary_background_fill="#E1E8ED",
    button_secondary_background_fill_hover="#CED9E0",
)
with gr.Blocks(theme=theme, css=css, title="NerdAI Math Solver") as demo:
    # --- State Variables ---
    # Per-session copies of the last extraction and classification, written by
    # process_image and read by the "Explain Further" / "Similar Questions" handlers.
    extracted_text_state = gr.State("")
    classification_state = gr.State("{}")  # Stored as a JSON string

    # --- UI Layout ---
    gr.Markdown("# 🧠 NerdAI Math Problem Solver", elem_id="title_markdown")
    gr.Markdown("Upload a clear image of a math problem. NerdAI will extract the text, classify it, solve it step-by-step, and offer further help!", elem_id="subtitle_markdown")
    with gr.Row():
        with gr.Column(scale=1, elem_id="input_col"):
            input_image = gr.Image(label="Upload Math Problem", type="pil", elem_id="input_image", height=350)
            process_btn = gr.Button("✨ Process Image and Solve", variant="primary", elem_id="process_button")
        with gr.Column(scale=1, elem_id="output_col"):
            processed_image = gr.Image(label="Processed Image", interactive=False, elem_id="processed_image", height=350)

    # --- Results Area ---
    # The former gr.Box() wrappers were removed: gr.Box no longer exists in
    # Gradio 4, so building the layout raised AttributeError.
    with gr.Group(elem_id="results_group"):
        gr.Markdown("### Results")
        extracted_text_output = gr.Textbox(label="📝 Extracted Text", lines=3, interactive=False, placeholder="Text from the image will appear here...", elem_id="extracted_text_output")
        classification_output = gr.Textbox(label=f"📊 Problem Classification ({CLASSIFICATION_MODEL})", lines=5, interactive=False, placeholder="Problem type analysis will appear here...", elem_id="classification_output")
        solution_output = gr.Markdown(label="✅ Solution Steps", value="*Solution steps will appear here after processing...*", elem_id="solution_output")
        # --- Action Buttons ---
        with gr.Row(elem_id="action_buttons"):
            explain_btn = gr.Button("🤔 Explain Further", variant="secondary")
            similar_btn = gr.Button("📚 Similar Questions", variant="secondary")
        # --- Accordions for Detailed Outputs ---
        with gr.Accordion("Detailed Explanation", open=False):
            explanation_output = gr.Markdown(value="*Click 'Explain Further' above to get a detailed breakdown.*")
        with gr.Accordion("Similar Practice Problems", open=False):
            similar_problems_output = gr.Markdown(value="*Click 'Similar Questions' above to generate practice problems.*")

    # --- Event Handlers ---
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[
            processed_image,
            extracted_text_output,
            classification_output,
            solution_output,
            extracted_text_state,  # Update state
            classification_state,  # Update state
        ],
        # api_name="process_math_image"  # Optional: for API usage
    )

    def _problem_text_missing(text):
        """True when *text* is empty or one of the placeholder/error strings process_image emits."""
        return not text or text.startswith("Error:") or text in ("No image provided.", "Please upload an image first.")

    # NOTE: both handlers are generators (they stream a "please wait" message).
    # Gradio discards a generator's `return` value, so every message must be
    # yielded; a bare `return` afterwards just ends the stream.
    def explain_button_handler(current_problem_text, current_solution_md):
        """Stream a detailed explanation of the current solution into the accordion."""
        print("Explain button clicked.")
        if _problem_text_missing(current_problem_text):
            yield "Please successfully process an image first to get text and a solution."
            return
        # lstrip("# ") also catches "## Error Generating Solution" headings.
        if (not current_solution_md
                or current_solution_md.lstrip("# ").startswith("Error")
                or "will appear here" in current_solution_md):
            yield "Cannot explain: A valid solution needs to be generated first."
            return
        yield "*Generating detailed explanation... please wait.*"
        yield explain_solution(current_problem_text, current_solution_md)

    explain_btn.click(
        fn=explain_button_handler,
        inputs=[extracted_text_state, solution_output],  # Use state and current solution output
        outputs=explanation_output,  # Target the Markdown inside the Accordion
    )

    def similar_button_handler(current_problem_text, current_classification_json):
        """Stream practice problems built from the stored text and classification."""
        print("Similar button clicked.")
        if _problem_text_missing(current_problem_text):
            yield "Please successfully process an image first to get the problem text and classification."
            return
        yield "*Generating similar problems... please wait.*"
        try:
            classification_dict = json.loads(current_classification_json)
            if not isinstance(classification_dict, dict) or not classification_dict:
                raise ValueError("Invalid classification data.")
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            print(f"Error parsing classification state for similar problems: {e}")
            yield f"Error: Could not use problem classification data ({e}). Please ensure the problem was classified correctly."
            return
        yield generate_similar_problems(current_problem_text, classification_dict)

    similar_btn.click(
        fn=similar_button_handler,
        inputs=[extracted_text_state, classification_state],  # Use state
        outputs=similar_problems_output,  # Target the Markdown inside the Accordion
    )

    # Example images; clicking one only fills the input (no auto-run).
    gr.Examples(
        examples=[
            ["examples/algebra_problem.png"],
            ["examples/calculus_problem.jpg"],
            ["examples/geometry_problem.png"],
        ],
        inputs=input_image,
        cache_examples=False,  # Better to re-run for dynamic models
        label="Example Math Problems",
    )
# --- Launch the App ---
if __name__ == "__main__":
    # Local testing aid: make sure every example path referenced by gr.Examples
    # exists, writing a simple coloured placeholder image for any that are missing.
    if not os.path.exists("examples"):
        os.makedirs("examples")
    placeholder_names = ("algebra_problem.png", "calculus_problem.jpg", "geometry_problem.png")
    for placeholder_name in placeholder_names:
        placeholder_path = os.path.join("examples", placeholder_name)
        if os.path.exists(placeholder_path):
            continue
        try:
            canvas = Image.new('RGB', (200, 100), color=(73, 109, 137))
            from PIL import ImageDraw
            ImageDraw.Draw(canvas).text((10, 10), f"Placeholder for\n{placeholder_name}", fill=(255, 255, 0))
            canvas.save(placeholder_path)
            print(f"Created placeholder example: {placeholder_path}")
        except Exception as e:
            print(f"Could not create placeholder image {placeholder_path}: {e}")
    # Queue requests for better handling of multiple users / long tasks;
    # debug=True prints extra logs (drop it for production).
    demo.queue().launch(debug=True)