File size: 3,387 Bytes
4a0cd82
7f97dd6
ee62e84
4557350
75bac2f
2dd2d70
4f80d37
75bac2f
ee62e84
4f80d37
75bac2f
 
4f80d37
f305096
4f80d37
 
ee62e84
4f80d37
 
 
 
 
75bac2f
4f80d37
ee62e84
 
75bac2f
4f80d37
75bac2f
 
 
4f80d37
 
 
 
 
 
 
 
 
 
75bac2f
4f80d37
 
75bac2f
4f80d37
 
75bac2f
 
4f80d37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75bac2f
ee62e84
4f80d37
4557350
75bac2f
ee62e84
4f80d37
75bac2f
4f80d37
ee62e84
4f80d37
 
 
 
 
 
 
 
 
 
 
 
ee62e84
4a0cd82
 
4f80d37
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import time

import cv2
import gradio as gr
import numpy as np
import torch
from ultralyticsplus import YOLO, render_result

# System Configuration: a one-shot banner reporting the torch build and
# whether a CUDA device is visible, framed by rules of 40 '=' characters.
print(
    "",
    "=" * 40,
    f"PyTorch: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
    "=" * 40 + "\n",
    sep="\n",
)

# Load the pretrained leaf detector from the Hugging Face hub id below.
model = YOLO('foduucom/plant-leaf-detection-and-classification')

# Inference defaults stored on the model and used by its predict() calls.
model.overrides.update(
    conf=0.15,            # low confidence floor: favours recall over precision
    iou=0.25,             # NMS IoU threshold; lower values suppress overlapping boxes MORE
    imgsz=1280,           # large inference size so small leaves keep enough pixels
    agnostic_nms=False,   # NMS is applied per class
    max_det=300,          # allow many detections in dense foliage
    device='cuda' if torch.cuda.is_available() else 'cpu',
    classes=None,         # no class filter: keep every class the model predicts
    half=torch.cuda.is_available(),  # FP16 inference only when a GPU is present
)

def _enhance_contrast(rgb):
    """Boost local contrast of an RGB image with CLAHE on its lightness channel.

    The image is converted to LAB, only the L channel is equalised with CLAHE
    (clip limit 3.0, 8x8 tiles), and the result is converted back to RGB.

    Args:
        rgb: image array in RGB channel order (presumably uint8 HxWx3 as
            delivered by Gradio -- confirm for other callers).

    Returns:
        A new RGB array with enhanced contrast.
    """
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    merged = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)


def _greedy_nms(boxes_xyxy, scores, iou_thres):
    """Class-agnostic greedy non-maximum suppression.

    Boxes are visited in descending score order; each kept box discards every
    remaining box whose IoU with it is greater than *iou_thres*.

    Args:
        boxes_xyxy: ``(N, 4)`` array-like of ``(x1, y1, x2, y2)`` corners.
        scores: ``(N,)`` array-like of confidences, one per box.
        iou_thres: IoU above which a lower-scoring box is suppressed.

    Returns:
        List of kept row indices into *boxes_xyxy*, highest score first.
        Empty input yields an empty list.
    """
    boxes = np.asarray(boxes_xyxy, dtype=float).reshape(-1, 4)
    scores = np.asarray(scores, dtype=float).reshape(-1)
    x1, y1, x2, y2 = boxes.T
    areas = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    order = np.argsort(-scores, kind="stable")
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        inter_w = np.clip(
            np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]), 0, None
        )
        inter_h = np.clip(
            np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]), 0, None
        )
        inter = inter_w * inter_h
        union = areas[best] + areas[rest] - inter
        # Degenerate (zero-area) pairs get IoU 0 instead of a division warning.
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        order = rest[iou <= iou_thres]
    return keep


def count_leaves(image):
    """Detect and count leaves in *image* and return an annotated copy.

    Pipeline:
    1. CLAHE contrast enhancement.
    2. YOLO inference with test-time augmentation.
    3. Drop boxes whose width or height is 20 px or less, or whose
       confidence is 0.1 or less.
    4. A second class-agnostic NMS pass at IoU 0.15.
    5. Draw green rectangles on the enhanced image for visual checking.

    Args:
        image: input picture from the Gradio ``Image`` component (array-like,
            converted with ``np.array``), or ``None`` when nothing was given.

    Returns:
        ``(annotated_image, leaf_count)``. On failure the error is printed and
        ``(image, 0)`` is returned so the UI still renders.
    """
    if image is None:
        # Nothing to process (e.g. the form was submitted without an upload).
        return None, 0

    try:
        start_time = time.time()

        image = np.array(image)
        enhanced_img = _enhance_contrast(image)

        # Prediction with overlap handling
        results = model.predict(
            source=enhanced_img,
            augment=True,  # test-time augmentation
            verbose=False,
            agnostic_nms=False,
            overlap_mask=False,
        )

        boxes = results[0].boxes
        xyxy = boxes.xyxy.cpu().numpy().reshape(-1, 4)
        conf = boxes.conf.cpu().numpy().reshape(-1)

        # Filter too-small boxes (adjust 20 px to your leaf sizes) and apply
        # the secondary confidence floor used before the extra NMS pass.
        widths = xyxy[:, 2] - xyxy[:, 0]
        heights = xyxy[:, 3] - xyxy[:, 1]
        valid = (widths > 20) & (heights > 20) & (conf > 0.1)
        xyxy, conf = xyxy[valid], conf[valid]

        # Second, class-agnostic NMS pass: collapses duplicate boxes on one
        # leaf even when the model assigned them different classes. Works on
        # xyxy corners so the kept boxes can be drawn directly.
        keep = _greedy_nms(xyxy, conf, iou_thres=0.15)
        final_boxes = xyxy[np.asarray(keep, dtype=np.intp)]
        num_leaves = len(final_boxes)

        # Visual validation
        debug_img = enhanced_img.copy()
        for box in final_boxes:
            # OpenCV wants plain Python ints for point coordinates.
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        print(f"Processing time: {time.time()-start_time:.2f}s")
        return debug_img, num_leaves

    except Exception as e:
        # UI boundary: report and fall back instead of crashing the app.
        print(f"Error: {str(e)}")
        return image, 0

# Gradio interface: one image in; the annotated image and the leaf count out.
interface = gr.Interface(
    fn=count_leaves,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Detection Visualization"),
        gr.Number(label="Estimated Leaf Count"),
    ],
    # U+1F343 written as a named escape so the title survives non-UTF-8
    # tooling; the literal emoji had been garbled into "πŸƒ".
    title="\N{LEAF FLUTTERING IN WIND} Advanced Leaf Counter",
    description="Specialized for overlapping leaves and dense foliage",
    examples=[
        ["sample_leaf1.jpg"],
        ["sample_leaf2.jpg"],
    ],
)

if __name__ == "__main__":
    # Serve on local port 7860 without creating a public share link.
    interface.launch(share=False, server_port=7860)