# app.py
"""Gradio front end for CropGuard, a leaf-disease classifier.

At import time this module loads the disease descriptions, the label
mapping and the trained model, then builds a ``gr.Blocks`` UI that lets a
user upload (or pick a sample) leaf image, runs the classifier on it and
shows the top-3 predictions together with a Grad-CAM overlay.
"""
import json
import os
import sys

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Add project root to sys.path so the ``src`` package resolves when this
# file is run directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
    sys.path.append(project_root)

# Internal modules (must come after the sys.path tweak above).
from src.model.architecture import build_model
from src.model.gradcam import GradCAMPlusPlus as GradCAM
from src.utils.config import BEST_MODEL_PATH

# Spatial size fed to the network and used for the Grad-CAM overlay.
IMAGE_SIZE = (224, 224)
# Below this top-1 probability the UI shows a "low confidence" alert.
LOW_CONFIDENCE_THRESHOLD = 0.6
# Fixed output file for the rendered overlay.
# NOTE(review): a single shared path means concurrent requests overwrite
# each other's image — consider per-request temp files if that matters.
CAM_OUTPUT_PATH = "cam_output.png"

# Load disease information (JSON is UTF-8 by specification).
with open("disease_info.json", "r", encoding="utf-8") as f:
    DISEASE_INFO = json.load(f)

# Load label mapping: keys are stringified class indices.
with open("models/labels.json", "r", encoding="utf-8") as f:
    idx_to_class = json.load(f)

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(idx_to_class)
model = build_model(num_classes=num_classes, freeze_backbone=False)
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()
# Last block of the feature extractor is the layer Grad-CAM attends to.
target_layer = model.features[-1]

# Sample images
SAMPLE_DIR = "sample_images"
sample_choices = sorted(os.listdir(SAMPLE_DIR))


# Utility Functions
def beautify_name(raw_classname):
    """Split a raw class label into display-ready ``(plant, disease)``.

    The label is split on ``"___"``; the second and third fields are taken
    as plant and disease. Labels with fewer than three fields yield
    ``("Unknown", "Unknown")``.
    """
    parts = raw_classname.split("___")
    if len(parts) < 3:
        return "Unknown", "Unknown"
    plant = parts[1].title()
    disease = (
        parts[2]
        .replace("_", " ")
        .replace("(", "")
        .replace(")", "")
        .title()
    )
    return plant, disease


def generate_gradcam(model, input_tensor, target_layer):
    """Return the Grad-CAM++ map produced for ``input_tensor``.

    Builds a fresh ``GradCAM`` around ``model``/``target_layer`` and returns
    whatever its ``generate`` method yields.
    """
    gradcam = GradCAM(model, target_layer)
    return gradcam.generate(input_tensor)


def preprocess_image(image):
    """Resize, tensorise and normalise a PIL image for the model.

    Returns a tensor without a batch dimension; the caller adds it.
    """
    # Imported locally because torchvision is not imported at file level.
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return transform(image)


def _save_cam_overlay(image, cam, out_path=CAM_OUTPUT_PATH):
    """Render the Grad-CAM heatmap over a grayscale copy of ``image``.

    Saves the figure to ``out_path`` and returns that path.
    ``cam`` is presumably a 2-D map in [0, 1] at ``IMAGE_SIZE`` resolution
    — TODO confirm against ``GradCAMPlusPlus.generate``.
    """
    img_np = np.array(image) / 255.0
    # Resize uploaded image so it matches the heatmap.
    img_resized = cv2.resize(img_np, IMAGE_SIZE)
    img_gray = cv2.cvtColor(
        np.uint8(img_resized * 255), cv2.COLOR_RGB2GRAY
    ) / 255.0
    img_gray_3ch = np.stack([img_gray] * 3, axis=-1)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_PLASMA)
    # OpenCV colour maps are BGR; matplotlib draws RGB. Convert so the
    # overlay colours agree with the plasma colorbar below.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    overlay = heatmap + img_gray_3ch
    overlay = overlay / np.max(overlay)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(overlay)
    ax.axis("off")
    cbar = fig.colorbar(
        plt.cm.ScalarMappable(cmap="plasma"),
        orientation="horizontal",
        pad=0.05,
        ax=ax,
    )
    cbar.set_ticks([0, 0.5, 1])
    cbar.set_ticklabels(["Low Focus", "Medium Focus", "High Focus"])
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)  # release the figure; pyplot keeps it alive otherwise
    return out_path


def _format_disease_info(pred_class):
    """Build the Markdown description for a diseased class label."""
    disease_data = DISEASE_INFO.get(pred_class, {})
    missing = "No information available."
    # Joined with blank lines so each section renders as its own paragraph
    # and no source indentation leaks into the Markdown.
    return "\n\n".join([
        f"**Symptoms:** {disease_data.get('symptoms', missing)}",
        f"**Causes:** {disease_data.get('causes', missing)}",
        f"**Disease Cycle:** {disease_data.get('disease_cycle', missing)}",
        f"**Care & Treatment:** {disease_data.get('care_treatment', missing)}",
        f"[Learn more on Wikipedia]({disease_data.get('wiki_url', '#')})",
    ])


def predict(image):
    """Classify a leaf image and prepare every UI output.

    Returns a 6-tuple matching the outputs wired to the Predict button:
    ``(health_status, identified_disease, top3_text, alert, cam_path,
    disease_info_text)``. All six are ``None`` when no image is supplied.
    """
    if image is None:
        # One value per wired output component (six of them).
        return None, None, None, None, None, None

    # The normalisation and the RGB->gray conversion both need 3 channels.
    image = image.convert("RGB")

    model.eval()
    image_tensor = preprocess_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    top3_indices = probs.argsort()[-3:][::-1]
    top3_classes = [
        (idx_to_class[str(idx)], float(probs[idx])) for idx in top3_indices
    ]
    pred_class, pred_prob = top3_classes[0]
    plant, disease = beautify_name(pred_class)

    # Format Top-3 nicely
    top3_lines = []
    for rank, (cls, prob) in enumerate(top3_classes, 1):
        c_plant, c_disease = beautify_name(cls)
        top3_lines.append(f"{rank}. {c_plant} - {c_disease} ({prob*100:.2f}%)\n")
    top3_text = "".join(top3_lines)

    # GradCAM (outside no_grad: it needs gradients)
    cam = generate_gradcam(model, image_tensor, target_layer)
    cam_path = _save_cam_overlay(image, cam)

    # Health Status + Disease Info
    if "healthy" in pred_class.lower():
        health_status = f"{plant} - Healthy"
        identified_disease = "None"
        disease_info_text = "✅ No disease detected."
    else:
        health_status = f"{plant} - Diseased"
        identified_disease = disease
        disease_info_text = _format_disease_info(pred_class)

    alert = None
    if pred_prob < LOW_CONFIDENCE_THRESHOLD:
        alert = "⚠️ Low confidence in prediction! Please verify manually."

    return (
        health_status,
        identified_disease,
        top3_text,
        alert,
        cam_path,
        disease_info_text,
    )


def load_sample_image(sample_name):
    """Open the named file from ``SAMPLE_DIR`` as an RGB PIL image."""
    img_path = os.path.join(SAMPLE_DIR, sample_name)
    return Image.open(img_path).convert("RGB")


# Interface
title = "CropGuard: Leaf Disease Detector"
copyright_text = "© 2025 Made by [Arka Mitra](https://github.com/mitraarka27)"
instruction_text = """
Upload a clear image of a **potato**, **tomato**, or **grape** leaf.
CropGuard will predict:

- Whether the leaf is **healthy** or **diseased**.
- The likely disease (if any).
- Where the model focused its attention.

⚡ **Note**: Currently supports only **Potato, Tomato, Grape** leaves.
"""

with gr.Blocks(theme="default") as app:
    with gr.Row():
        # NOTE(review): these two strings look like they once carried
        # heading/paragraph markup that was lost — confirm intended styling.
        gr.Markdown(f"\n\n{title}\n\n")
        gr.Markdown("\n\n© 2025 Made by Arka Mitra\n\n")

    with gr.Row():
        with gr.Column(scale=2):
            upload = gr.Image(
                type="pil",
                sources=["upload", "webcam", "clipboard"],
                label="Upload, Capture, or Paste Leaf Image",
            )
            gr.Markdown("**OR** choose from sample images below:")
            sample_dropdown = gr.Dropdown(
                choices=sample_choices, label="Select a Sample Image"
            )
            load_btn = gr.Button("Load Sample Image")
            predict_btn = gr.Button("Predict", variant="primary")
            gr.Markdown(instruction_text)
            alert_box = gr.Textbox(
                label="Prediction Alert", lines=2, interactive=False
            )
            top3_preds = gr.Textbox(
                label="Top-3 Predictions", lines=5, interactive=False
            )

        with gr.Column(scale=3):
            health_status = gr.Label(label="Plant Health Status")
            disease_name = gr.Label(label="Identified Disease (includes details)")
            disease_info = gr.Markdown()
            heatmap = gr.Image(label="Model Focus Heatmap")

    load_btn.click(
        fn=load_sample_image,
        inputs=[sample_dropdown],
        outputs=[upload],
    )
    # The order here must match the tuple returned by ``predict``.
    predict_btn.click(
        fn=predict,
        inputs=[upload],
        outputs=[
            health_status,
            disease_name,
            top3_preds,
            alert_box,
            heatmap,
            disease_info,
        ],
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)