# Source: Hugging Face Space (status at capture: Sleeping); file size 2,746 bytes; commit 81f02dd.
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from ultralytics import YOLO
# Initialize device: use CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class labels, in the order of the classifier's output logits.
class_labels = ["c9", "kant", "superf"]

# Load models
# YOLO detector weights. Make sure this file is uploaded to your Space.
yolo_model = YOLO("best.pt")


def _build_rice_classifier(weights_path, num_classes):
    """Build a ResNet-50 with a ``num_classes``-way head and load its weights.

    Args:
        weights_path: Path to a saved ``state_dict`` for the modified ResNet-50.
        num_classes: Number of output logits for the replacement ``fc`` layer.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    model = models.resnet50(pretrained=False)
    # Swap the stock classification head for one sized to our rice classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(device).eval()


# Modify ResNet for 3 classes (one per entry in class_labels).
resnet = _build_rice_classifier("rice_resnet_model.pth", len(class_labels))

# Image transformations: resize to 224x224, convert to a tensor, then normalize
# with the commonly used ImageNet channel statistics.
# Presumably this matches the classifier's training preprocessing; confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
def classify_crop(crop_img):
    """Predict the rice variety of a single cropped grain.

    Args:
        crop_img: Image of one grain, accepted by the module-level ``transform``.

    Returns:
        The entry of ``class_labels`` with the highest classifier logit.
    """
    # transform -> (C, 224, 224); unsqueeze adds the batch axis -> (1, C, 224, 224).
    batch = transform(crop_img).unsqueeze(0)
    batch = batch.to(device)
    with torch.no_grad():
        logits = resnet(batch)
    best_index = logits.argmax(dim=1).item()
    return class_labels[best_index]
def detect_and_classify(image):
    """Detect rice grains with YOLO, classify each with ResNet, and draw the results.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            user submits without uploading anything.

    Returns:
        An RGB ``PIL.Image`` with a box and predicted label drawn for every
        detection, or ``None`` when no image was supplied.
    """
    if image is None:
        return None

    # Uploads may be RGBA, grayscale or palette images; the classifier's
    # 3-channel normalization and the drawing code both need plain RGB.
    rgb_image = image.convert("RGB")

    # Ultralytics interprets numpy arrays as BGR but PIL images as RGB, so pass
    # the PIL image itself rather than an RGB numpy array.
    results = yolo_model(rgb_image)[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    # Crop from an untouched copy and draw on a separate one, so boxes/labels
    # drawn for earlier detections never leak into crops of overlapping ones.
    pristine = np.array(rgb_image)
    annotated = pristine.copy()
    height, width = pristine.shape[:2]

    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        # Clamp to the frame and skip degenerate boxes, which would otherwise
        # produce an empty crop.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, width), min(y2, height)
        if x2 <= x1 or y2 <= y1:
            continue

        # The array is already RGB, so no channel conversion is needed.
        crop_pil = Image.fromarray(pristine[y1:y2, x1:x2])
        predicted_label = classify_crop(crop_pil)

        # Draw bounding box and label; keep the label inside the frame for
        # boxes that touch the top edge.
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(annotated, predicted_label, (x1, max(y1 - 10, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    return Image.fromarray(annotated)
# Gradio Interface
# Example images offered in the UI (add your example images here). Files that
# are not present are skipped: with cache_examples=True Gradio runs the model on
# every example at startup, so a missing file would otherwise abort the app.
EXAMPLE_FILES = ["example1.jpg", "example2.jpg"]

with gr.Blocks(title="چاول کا شناختی نظام") as demo:
    # Title: "Rice identification system". Body: "Upload an image containing
    # rice grains. The system will identify each grain and report its type."
    gr.Markdown("""
# چاول کا شناختی نظام
ایک تصویر اپ لوڈ کریں جس میں چاول کے دانے ہوں۔ نظام ہر دانے کو پہچان کر اس کی قسم بتائے گا۔
""")
    with gr.Row():
        # "Input image" / "Result"
        input_image = gr.Image(type="pil", label="تصویر داخل کریں")
        output_image = gr.Image(type="pil", label="نتیجہ")
    # "Identify"
    submit_btn = gr.Button("تشخیص کریں")
    submit_btn.click(
        fn=detect_and_classify,
        inputs=input_image,
        outputs=output_image,
    )

    available_examples = [[path] for path in EXAMPLE_FILES if Path(path).is_file()]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=input_image,
            outputs=output_image,
            fn=detect_and_classify,
            cache_examples=True,
        )

demo.launch()