hafizarslan commited on
Commit
8445179
·
verified ·
1 Parent(s): c78d0bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -26
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import cv2
2
- import torch
3
  import numpy as np
4
  from PIL import Image
 
5
  from torchvision import models, transforms
6
  from ultralytics import YOLO
7
  import gradio as gr
@@ -11,10 +11,8 @@ import torch.nn as nn
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Load models
14
- yolo_model = YOLO('best.pt') # Make sure this file is uploaded to your Space
15
  resnet = models.resnet50(pretrained=False)
16
-
17
- # Modify ResNet for 3 classes
18
  resnet.fc = nn.Linear(resnet.fc.in_features, 3)
19
  resnet.load_state_dict(torch.load('rice_resnet_model.pth', map_location=device))
20
  resnet = resnet.to(device)
@@ -38,12 +36,17 @@ def classify_crop(crop_img):
38
  _, predicted = torch.max(output, 1)
39
  return class_labels[predicted.item()]
40
 
41
- def detect_and_classify(image):
42
- """Process full image with YOLO + ResNet"""
43
- image = np.array(image)
 
 
 
 
44
  results = yolo_model(image)[0]
45
  boxes = results.boxes.xyxy.cpu().numpy()
46
-
 
47
  for box in boxes:
48
  x1, y1, x2, y2 = map(int, box[:4])
49
  crop = image[y1:y2, x1:x2]
@@ -52,35 +55,36 @@ def detect_and_classify(image):
52
 
53
  # Draw bounding box and label
54
  cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
55
- cv2.putText(image, predicted_label, (x1, y1-10),
56
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
 
 
 
 
 
57
 
 
58
  return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
59
 
60
- # Gradio Interface
61
- with gr.Blocks(title="چاول کا شناختی نظام") as demo:
62
  gr.Markdown("""
63
- # چاول کا شناختی نظام
64
- ایک تصویر اپ لوڈ کریں جس میں چاول کے دانے ہوں۔ نظام ہر دانے کو پہچان کر اس کی قسم بتائے گا۔
65
  """)
66
 
67
  with gr.Row():
68
- input_image = gr.Image(type="pil", label="تصویر داخل کریں")
69
- output_image = gr.Image(type="pil", label="نتیجہ")
 
 
 
70
 
71
- submit_btn = gr.Button("تشخیص کریں")
72
  submit_btn.click(
73
  fn=detect_and_classify,
74
- inputs=input_image,
75
  outputs=output_image
76
  )
77
-
78
- gr.Examples(
79
- examples=[["example1.jpg"], ["example2.jpg"]], # Add your example images
80
- inputs=input_image,
81
- outputs=output_image,
82
- fn=detect_and_classify,
83
- cache_examples=True
84
- )
85
 
 
86
  demo.launch()
 
1
  import cv2
 
2
  import numpy as np
3
  from PIL import Image
4
+ import torch
5
  from torchvision import models, transforms
6
  from ultralytics import YOLO
7
  import gradio as gr
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
  # Load models
14
+ yolo_model = YOLO('best.pt') # Make sure this file is uploaded
15
  resnet = models.resnet50(pretrained=False)
 
 
16
  resnet.fc = nn.Linear(resnet.fc.in_features, 3)
17
  resnet.load_state_dict(torch.load('rice_resnet_model.pth', map_location=device))
18
  resnet = resnet.to(device)
 
36
  _, predicted = torch.max(output, 1)
37
  return class_labels[predicted.item()]
38
 
39
+ def detect_and_classify(input_image):
40
+ """Process uploaded image"""
41
+ # Convert Gradio Image to OpenCV format
42
+ image = np.array(input_image)
43
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
44
+
45
+ # YOLO Detection
46
  results = yolo_model(image)[0]
47
  boxes = results.boxes.xyxy.cpu().numpy()
48
+
49
+ # Process each detection
50
  for box in boxes:
51
  x1, y1, x2, y2 = map(int, box[:4])
52
  crop = image[y1:y2, x1:x2]
 
55
 
56
  # Draw bounding box and label
57
  cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
58
+ cv2.putText(image,
59
+ predicted_label,
60
+ (x1, y1-10),
61
+ cv2.FONT_HERSHEY_SIMPLEX,
62
+ 0.9,
63
+ (36, 255, 12),
64
+ 2)
65
 
66
+ # Convert back to RGB for Gradio
67
  return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
68
 
69
+ # Create Gradio interface
70
+ with gr.Blocks(title="Rice Classification") as demo:
71
  gr.Markdown("""
72
+ ## 🍚 Rice Variety Classifier
73
+ Upload an image containing rice grains. The system will detect and classify each grain.
74
  """)
75
 
76
  with gr.Row():
77
+ with gr.Column():
78
+ image_input = gr.Image(type="pil", label="Upload Rice Image")
79
+ submit_btn = gr.Button("Analyze", variant="primary")
80
+ with gr.Column():
81
+ output_image = gr.Image(label="Detection Results", interactive=False)
82
 
 
83
  submit_btn.click(
84
  fn=detect_and_classify,
85
+ inputs=image_input,
86
  outputs=output_image
87
  )
 
 
 
 
 
 
 
 
88
 
89
+ # Launch the app
90
  demo.launch()