# Page header captured together with this file (not part of the program):
#   Spaces: Sleeping
#   File size: 6,902 Bytes
#   09823ea 62cebeb
#   (the viewer's line-number gutter, 1-206, is omitted)
# app.py
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import sys
import json
from PIL import Image
# Add project root to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.append(project_root)
# Internal modules
from src.model.architecture import build_model
from src.model.gradcam import GradCAMPlusPlus as GradCAM
from src.utils.config import BEST_MODEL_PATH
# Load disease information
with open("disease_info.json", "r") as f:
DISEASE_INFO = json.load(f)
# Load label mapping
with open("models/labels.json", "r") as f:
idx_to_class = json.load(f)
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(idx_to_class)
model = build_model(num_classes=num_classes, freeze_backbone=False)
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()
target_layer = model.features[-1]
# Sample images
SAMPLE_DIR = "sample_images"
sample_choices = sorted(os.listdir(SAMPLE_DIR))
# Utility Functions
def beautify_name(raw_classname):
    """Turn a raw class label into display-ready ``(plant, disease)`` names.

    The label is split on triple underscores. With at least three pieces,
    the second piece is the plant and the third is the disease; both are
    title-cased, and the disease additionally has underscores turned into
    spaces and parentheses removed. With fewer pieces the result is
    ``("Unknown", "Unknown")``.
    """
    pieces = raw_classname.split("___")
    if len(pieces) < 3:
        return "Unknown", "Unknown"

    plant_raw, disease_raw = pieces[1], pieces[2]
    # Once every "_" has become a space no "__" can remain, so these three
    # substitutions are the complete clean-up.
    for old, new in (("_", " "), ("(", ""), (")", "")):
        disease_raw = disease_raw.replace(old, new)
    return plant_raw.title(), disease_raw.title()
def generate_gradcam(model, input_tensor, target_layer):
    """Build a Grad-CAM helper for ``model``/``target_layer`` and run it.

    Returns whatever the helper's ``generate(input_tensor)`` call returns.
    NOTE(review): `predict` multiplies the result by 255 and casts it to
    uint8, so it is presumably an array scaled to 0..1 -- confirm against
    the Grad-CAM implementation.
    """
    return GradCAM(model, target_layer).generate(input_tensor)
def preprocess_image(image):
    """Resize, tensorise and normalise ``image`` for the classifier.

    The image is resized to 224x224, converted to a tensor, and normalised
    per channel with the mean/std constants below (they match the widely
    used ImageNet statistics). The caller adds the batch dimension.
    """
    from torchvision import transforms

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(image)
def predict(image):
    """Classify a leaf image and build every value the UI displays.

    Args:
        image: PIL image from the upload widget, or ``None`` when nothing
            has been provided.

    Returns:
        A 6-tuple, one entry per output component wired to the Predict
        button: ``(health_status, identified_disease, top3_text, alert,
        cam_path, disease_info_text)``. ``alert`` is ``None`` unless the top
        probability is below 0.6. All six entries are ``None`` when
        ``image`` is ``None``.
    """
    if image is None:
        # Six placeholders: the click handler declares six outputs, and the
        # number of returned values has to match.
        return None, None, None, None, None, None

    # No-op for RGB input; protects the 3-channel normalisation and the
    # RGB->gray conversion below from RGBA / grayscale images.
    image = image.convert("RGB")

    model.eval()
    image_tensor = preprocess_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    # Three most probable classes, best first.
    top3_indices = probs.argsort()[-3:][::-1]
    top3_classes = [(idx_to_class[str(idx)], float(probs[idx])) for idx in top3_indices]
    pred_class, pred_prob = top3_classes[0]
    plant, disease = beautify_name(pred_class)

    # Format Top-3 nicely (one numbered line per candidate).
    top3_lines = []
    for rank, (class_name, prob) in enumerate(top3_classes, 1):
        c_plant, c_disease = beautify_name(class_name)
        top3_lines.append(f"{rank}. {c_plant} - {c_disease} ({prob*100:.2f}%)\n")
    top3_text = "".join(top3_lines)

    # GradCAM (runs outside no_grad; the helper is given the same tensor).
    cam = generate_gradcam(model, image_tensor, target_layer)

    # Grayscale copy of the upload at 224x224, replicated to three channels
    # so the coloured heatmap can be added on top of it.
    # NOTE(review): assumes `cam` is a 224x224 array in 0..1 -- confirm.
    img_np = np.array(image) / 255.0
    img_np_resized = cv2.resize(img_np, (224, 224))
    img_gray = cv2.cvtColor(np.uint8(img_np_resized * 255), cv2.COLOR_RGB2GRAY) / 255.0
    img_gray_3ch = np.stack([img_gray] * 3, axis=-1)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_PLASMA)
    # applyColorMap yields BGR; matplotlib displays RGB. Without this swap
    # the overlay colours would not match the 'plasma' colorbar below.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    overlay = heatmap + img_gray_3ch
    overlay = overlay / np.max(overlay)

    cam_path = "cam_output.png"
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(overlay)
        ax.axis("off")
        cbar = plt.colorbar(
            plt.cm.ScalarMappable(cmap='plasma'),
            orientation='horizontal',
            pad=0.05,
            ax=ax,
        )
        cbar.set_ticks([0, 0.5, 1])
        cbar.set_ticklabels(['Low Focus', 'Medium Focus', 'High Focus'])
        plt.tight_layout()
        fig.savefig(cam_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    # Health Status + Disease Info
    if "healthy" in pred_class.lower():
        health_status = f"{plant} - Healthy"
        identified_disease = "None"
        disease_info_text = "✅ No disease detected."
    else:
        health_status = f"{plant} - Diseased"
        identified_disease = disease
        disease_data = DISEASE_INFO.get(pred_class, {})
        # The literal's lines stay flush-left so no indentation leaks into
        # the Markdown that is rendered from it.
        disease_info_text = f"""
**Symptoms:** {disease_data.get('symptoms', 'No information available.')}
**Causes:** {disease_data.get('causes', 'No information available.')}
**Disease Cycle:** {disease_data.get('disease_cycle', 'No information available.')}
**Care & Treatment:** {disease_data.get('care_treatment', 'No information available.')}
[Learn more on Wikipedia]({disease_data.get('wiki_url', '#')})
"""

    alert = None
    if pred_prob < 0.6:
        alert = "⚠️ Low confidence in prediction! Please verify manually."

    return health_status, identified_disease, top3_text, alert, cam_path, disease_info_text
def load_sample_image(sample_name):
    """Open the chosen sample image as an RGB PIL image.

    Args:
        sample_name: File name inside ``SAMPLE_DIR``, or ``None``/empty when
            the dropdown has no selection.

    Returns:
        The image converted to RGB, or ``None`` when nothing is selected
        (instead of failing while building the path).
    """
    if not sample_name:
        return None

    img_path = os.path.join(SAMPLE_DIR, sample_name)
    # convert() produces an independent, fully loaded image, so the file
    # handle can be closed as soon as the block exits.
    with Image.open(img_path) as img:
        return img.convert("RGB")
# Interface text
title = "CropGuard: Leaf Disease Detector"
# Markdown variant of the credit line; the layout code visible in this file
# writes its own HTML credit instead of using this constant.
copyright_text = "© 2025 Made by [Arka Mitra](https://github.com/mitraarka27)"
# Usage notes rendered as Markdown in the input column. Built from adjacent
# literals, one per line of the final text (leading and trailing newline
# included).
instruction_text = (
    "\n"
    "Upload a clear image of a **potato**, **tomato**, or **grape** leaf.\n"
    "CropGuard will predict:\n"
    "- Whether the leaf is **healthy** or **diseased**.\n"
    "- The likely disease (if any).\n"
    "- Where the model focused its attention.\n"
    "⚡ **Note**: Currently supports only **Potato, Tomato, Grape** leaves.\n"
)
# Page layout and event wiring.
# NOTE(review): the nesting below is reconstructed from statement order; in
# particular, whether the credit line sits inside the first row and where the
# rows/columns end should be confirmed against the running app.
with gr.Blocks(theme="default") as app:
    # Header: centred title, then the credit line.
    with gr.Row():
        gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown("<p style='text-align: center;'>© 2025 Made by <a href='https://github.com/mitraarka27' target='_blank'>Arka Mitra</a></p>")
    with gr.Row():
        # Input side: image source, sample picker, actions, usage notes and
        # the textual prediction outputs.
        with gr.Column(scale=2):
            # type="pil" means `predict` receives a PIL image (or None).
            upload = gr.Image(
                type="pil",
                sources=["upload", "webcam", "clipboard"],
                label="Upload, Capture, or Paste Leaf Image"
            )
            gr.Markdown("**OR** choose from sample images below:")
            sample_dropdown = gr.Dropdown(choices=sample_choices, label="Select a Sample Image")
            load_btn = gr.Button("Load Sample Image")
            predict_btn = gr.Button("Predict", variant="primary")
            gr.Markdown(instruction_text)
            alert_box = gr.Textbox(label="Prediction Alert", lines=2, interactive=False)
            top3_preds = gr.Textbox(label="Top-3 Predictions", lines=5, interactive=False)
        # Result side: status, disease name, disease details and heatmap.
        with gr.Column(scale=3):
            health_status = gr.Label(label="Plant Health Status")
            disease_name = gr.Label(label="Identified Disease (includes details)")
            disease_info = gr.Markdown()
            heatmap = gr.Image(label="Model Focus Heatmap")
    # Copy the selected sample into the upload widget.
    load_btn.click(
        fn=load_sample_image,
        inputs=[sample_dropdown],
        outputs=[upload]
    )
    # Run inference; `predict` has to return one value per component listed
    # in `outputs`, in this order.
    predict_btn.click(
        fn=predict,
        inputs=[upload],
        outputs=[health_status, disease_name, top3_preds, alert_box, heatmap, disease_info]
    )
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces; the port is fixed at 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)