# CropGuard / app.py
# mitraarka27's picture
# 🔧 Fix: change server_name to 0.0.0.0 for Hugging Face
# 62cebeb
# app.py
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import sys
import json
from PIL import Image
# Add project root to sys.path
# "Project root" here is the directory one level above the folder that
# contains this file; it is appended only when it is not already listed.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if not (project_root in sys.path):
    sys.path.append(project_root)
# Internal modules
from src.model.architecture import build_model
from src.model.gradcam import GradCAMPlusPlus as GradCAM
from src.utils.config import BEST_MODEL_PATH
# Load disease information
# Keys are looked up by predicted class name; each value is read with
# .get() for 'symptoms', 'causes', 'disease_cycle', 'care_treatment', 'wiki_url'.
# The encoding is given explicitly so decoding does not depend on the
# host's locale default.
with open("disease_info.json", "r", encoding="utf-8") as f:
    DISEASE_INFO = json.load(f)
# Load label mapping
# Maps the stringified class index (e.g. "0") to the raw class name.
with open("models/labels.json", "r", encoding="utf-8") as f:
    idx_to_class = json.load(f)
# Load model
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# One output class per entry in the label mapping.
num_classes = len(idx_to_class)
model = build_model(num_classes=num_classes, freeze_backbone=False)
# NOTE(review): torch.load unpickles the checkpoint file; only point
# BEST_MODEL_PATH at checkpoints you trust.
model.load_state_dict(
    torch.load(BEST_MODEL_PATH, map_location=device)
)
model = model.to(device)
model.eval()
# The last entry of model.features is the layer used for Grad-CAM.
target_layer = model.features[-1]
# Sample images
# Every directory entry is offered in the dropdown, in sorted order.
SAMPLE_DIR = "sample_images"
sample_choices = os.listdir(SAMPLE_DIR)
sample_choices.sort()
# Utility Functions
def beautify_name(raw_classname):
    """Split a raw class label into display-ready (plant, disease) names.

    The label is split on triple underscores. When it yields at least three
    fields, the second field is title-cased as the plant and the third is
    cleaned (underscores become spaces, parentheses are dropped) and
    title-cased as the disease. Any label with fewer than three fields
    produces ("Unknown", "Unknown").
    """
    fields = raw_classname.split("___")
    if len(fields) < 3:
        return "Unknown", "Unknown"

    plant_field, disease_field = fields[1], fields[2]
    # Single pass: "_" -> " ", "(" and ")" removed.
    cleanup = str.maketrans({"_": " ", "(": None, ")": None})
    return plant_field.title(), disease_field.translate(cleanup).title()
def generate_gradcam(model, input_tensor, target_layer):
    """Return the class-activation map for ``input_tensor``.

    A fresh ``GradCAM`` object is constructed from ``model`` and
    ``target_layer`` on every call, and the result of its ``generate``
    method is returned unchanged.
    """
    # NOTE(review): if the GradCAM constructor registers hooks on the layer,
    # they would accumulate across calls — confirm in src.model.gradcam.
    return GradCAM(model, target_layer).generate(input_tensor)
def preprocess_image(image):
    """Convert ``image`` into a normalised tensor for the classifier.

    The image is resized to 224x224, converted to a tensor, and normalised
    per channel with the mean/std constants listed below. No batch
    dimension is added here.
    """
    from torchvision import transforms

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(image)
# Width/height (pixels) of the image the Grad-CAM heatmap is blended onto.
_CAM_SIZE = (224, 224)
# Top-1 probabilities below this value trigger the low-confidence alert.
_LOW_CONFIDENCE_THRESHOLD = 0.6


def _build_overlay(image, cam):
    """Blend the Grad-CAM map over a grayscale copy of ``image`` (RGB result)."""
    img_np = np.array(image) / 255.0
    # Resize the uploaded image to 224x224 so it matches the heatmap.
    # NOTE(review): assumes `cam` is a 224x224 array with values in [0, 1]
    # — confirm against GradCAM.generate.
    img_np_resized = cv2.resize(img_np, _CAM_SIZE)
    img_gray = cv2.cvtColor(np.uint8(img_np_resized * 255), cv2.COLOR_RGB2GRAY) / 255.0
    img_gray_3ch = np.stack([img_gray] * 3, axis=-1)
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_PLASMA)
    # applyColorMap yields BGR; matplotlib displays RGB. Without this swap the
    # overlay colours do not match the plasma colourbar drawn next to it.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap_rgb) / 255
    overlay = heatmap + img_gray_3ch
    return overlay / np.max(overlay)


def _save_overlay_figure(overlay, cam_path="cam_output.png"):
    """Render ``overlay`` with a focus colourbar, save it, and return the path."""
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(overlay)
        ax.axis("off")
        # Figure-level calls (not plt.*) so the pyplot "current figure"
        # global is never consulted.
        cbar = fig.colorbar(
            plt.cm.ScalarMappable(cmap='plasma'),
            orientation='horizontal',
            pad=0.05,
            ax=ax,
        )
        cbar.set_ticks([0, 0.5, 1])
        cbar.set_ticklabels(['Low Focus', 'Medium Focus', 'High Focus'])
        fig.tight_layout()
        fig.savefig(cam_path)
    finally:
        # Always release the figure, even if rendering or saving failed.
        plt.close(fig)
    return cam_path


def _format_disease_info(disease_data):
    """Build the Markdown description block from one DISEASE_INFO entry."""
    missing = 'No information available.'
    return (
        "\n"
        f"**Symptoms:** {disease_data.get('symptoms', missing)}\n"
        f"**Causes:** {disease_data.get('causes', missing)}\n"
        f"**Disease Cycle:** {disease_data.get('disease_cycle', missing)}\n"
        f"**Care & Treatment:** {disease_data.get('care_treatment', missing)}\n"
        f"[Learn more on Wikipedia]({disease_data.get('wiki_url', '#')})\n"
    )


def predict(image):
    """Classify a leaf image and produce every value the UI displays.

    Args:
        image: The image supplied by the upload component (configured with
            ``type="pil"``), or ``None`` when nothing has been uploaded.

    Returns:
        A 6-tuple, in the order the outputs are wired to the Predict button:
        ``(health_status, identified_disease, top3_text, alert, cam_path,
        disease_info_text)``. ``alert`` is ``None`` unless the top-1
        probability is below 0.6. When ``image`` is ``None`` all six entries
        are ``None``.

    Side effects:
        Writes the Grad-CAM overlay to ``cam_output.png`` in the working
        directory; the same file name is reused on every call.
    """
    if image is None:
        # One placeholder per output component (six), matching the normal
        # return below.
        return None, None, None, None, None, None

    model.eval()
    image_tensor = preprocess_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
    probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    # Indices of the three largest probabilities, most likely first.
    top3_indices = probs.argsort()[-3:][::-1]
    top3_classes = [(idx_to_class[str(idx)], float(probs[idx])) for idx in top3_indices]
    pred_class, pred_prob = top3_classes[0]
    plant, disease = beautify_name(pred_class)

    # Format Top-3 nicely
    top3_lines = []
    for rank, (class_name, prob) in enumerate(top3_classes, 1):
        c_plant, c_disease = beautify_name(class_name)
        top3_lines.append(f"{rank}. {c_plant} - {c_disease} ({prob*100:.2f}%)\n")
    top3_text = "".join(top3_lines)

    # GradCAM
    cam = generate_gradcam(model, image_tensor, target_layer)
    cam_path = _save_overlay_figure(_build_overlay(image, cam))

    # Health Status + Disease Info
    if "healthy" in pred_class.lower():
        health_status = f"{plant} - Healthy"
        identified_disease = "None"
        disease_info_text = "✅ No disease detected."
    else:
        health_status = f"{plant} - Diseased"
        identified_disease = disease
        disease_info_text = _format_disease_info(DISEASE_INFO.get(pred_class, {}))

    alert = None
    if pred_prob < _LOW_CONFIDENCE_THRESHOLD:
        alert = "⚠️ Low confidence in prediction! Please verify manually."

    return health_status, identified_disease, top3_text, alert, cam_path, disease_info_text
def load_sample_image(sample_name):
    """Open a bundled sample image by file name and return it as RGB.

    Args:
        sample_name: File name inside ``SAMPLE_DIR``, as chosen in the
            sample dropdown. May be ``None`` or empty when nothing is
            selected.

    Returns:
        The image converted to RGB mode, or ``None`` when no sample name
        was given (instead of failing while building the path).
    """
    if not sample_name:
        return None
    img_path = os.path.join(SAMPLE_DIR, sample_name)
    # convert() returns a separate image object, so the source file can be
    # closed as soon as the conversion is done.
    with Image.open(img_path) as img:
        return img.convert("RGB")
# Interface
# Static text used when building the page.
title = "CropGuard: Leaf Disease Detector"
copyright_text = "© 2025 Made by [Arka Mitra](https://github.com/mitraarka27)"
# Markdown shown under the buttons; starts and ends with a newline.
instruction_text = (
    "\n"
    "Upload a clear image of a **potato**, **tomato**, or **grape** leaf.\n"
    "CropGuard will predict:\n"
    "- Whether the leaf is **healthy** or **diseased**.\n"
    "- The likely disease (if any).\n"
    "- Where the model focused its attention.\n"
    "⚡ **Note**: Currently supports only **Potato, Tomato, Grape** leaves.\n"
)
# Build the Gradio page: header text, an input area (upload / sample picker /
# buttons) and an output area (status labels, disease details, heatmap).
# NOTE(review): the Row/Column nesting below was reconstructed from a copy of
# this file whose indentation had been lost — confirm which container each
# component belongs to.
with gr.Blocks(theme="default") as app:
    # Header: centred title and author credit (raw HTML inside Markdown).
    with gr.Row():
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
        gr.Markdown("<p style='text-align: center;'>© 2025 Made by <a href='https://github.com/mitraarka27' target='_blank'>Arka Mitra</a></p>")
    with gr.Row():
        # Input side.
        with gr.Column(scale=2):
            # type="pil" means predict() receives a PIL image (or None).
            upload = gr.Image(
                type="pil",
                sources=["upload", "webcam", "clipboard"],
                label="Upload, Capture, or Paste Leaf Image"
            )
            gr.Markdown("**OR** choose from sample images below:")
            sample_dropdown = gr.Dropdown(choices=sample_choices, label="Select a Sample Image")
            load_btn = gr.Button("Load Sample Image")
            predict_btn = gr.Button("Predict", variant="primary")
            gr.Markdown(instruction_text)
            alert_box = gr.Textbox(label="Prediction Alert", lines=2, interactive=False)
            top3_preds = gr.Textbox(label="Top-3 Predictions", lines=5, interactive=False)
        # Output side.
        with gr.Column(scale=3):
            health_status = gr.Label(label="Plant Health Status")
            disease_name = gr.Label(label="Identified Disease (includes details)")
            disease_info = gr.Markdown()
            heatmap = gr.Image(label="Model Focus Heatmap")
    # Copy the chosen sample file into the upload component.
    load_btn.click(
        fn=load_sample_image,
        inputs=[sample_dropdown],
        outputs=[upload]
    )
    # The six outputs are listed in the same order as predict()'s return tuple.
    predict_btn.click(
        fn=predict,
        inputs=[upload],
        outputs=[health_status, disease_name, top3_preds, alert_box, heatmap, disease_info]
    )
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860, when run as a script.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )