# Hugging Face Space file header (scraped page residue, kept as a comment so the file parses):
# author: hafizarslan · commit 81f02dd ("Update app.py", verified) · 2.75 kB
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from ultralytics import YOLO
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# YOLO detector whose boxes are used to crop individual grains.
# The 'best.pt' checkpoint must be shipped alongside this script (e.g. in the Space).
yolo_model = YOLO("best.pt")

# ResNet-50 built without pretrained weights; its final fully-connected layer
# is swapped for a 3-way head (one logit per variety), then the saved
# weights are loaded on top and the network is put in inference mode.
resnet = models.resnet50(pretrained=False)
resnet.fc = nn.Linear(resnet.fc.in_features, 3)
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
resnet.load_state_dict(torch.load("rice_resnet_model.pth", map_location=device))
resnet = resnet.to(device).eval()
# Maps the classifier's output index to a rice-variety name.
# NOTE(review): presumably this order matches the label order used in training — confirm.
class_labels = ["c9", "kant", "superf"]

# Per-channel ImageNet mean / std used for input normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing for every grain crop: resize to 224x224, convert to a
# tensor, then normalise each channel with the statistics above.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
def classify_crop(crop_img):
    """Predict the rice variety of one cropped grain.

    Args:
        crop_img: Image of a single grain, in any form ``transform`` accepts
            (the caller in this file passes a PIL image).

    Returns:
        The entry of ``class_labels`` whose logit is highest.
    """
    # Preprocess, add a batch dimension of 1, and move to the model's device.
    batch = transform(crop_img).unsqueeze(0).to(device)
    with torch.no_grad():  # pure inference: no autograd graph needed
        logits = resnet(batch)
    # Highest-scoring class for the single item in the batch.
    best_index = logits.argmax(dim=1).item()
    return class_labels[best_index]
def detect_and_classify(image):
    """Detect rice grains with YOLO, classify each one, and draw the results.

    Args:
        image: Picture from the Gradio ``gr.Image(type="pil")`` input (a PIL
            image). A NumPy array is also accepted and wrapped as a PIL image.
            ``None`` (nothing uploaded) is passed straight through.

    Returns:
        A PIL RGB image with a green rectangle and the predicted variety label
        drawn for every detected grain, or ``None`` when no image was given.
    """
    if image is None:
        return None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    # Work in RGB throughout. PIL/Gradio supply RGB, and the classifier's
    # PIL-based transform expects RGB, so no BGR<->RGB swap is applied to the
    # crops or to the returned image.
    rgb = np.array(image.convert("RGB"))
    height, width = rgb.shape[:2]

    # Ultralytics interprets NumPy inputs as BGR (OpenCV convention), so hand
    # it an explicit BGR copy rather than the RGB array.
    results = yolo_model(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))[0]
    boxes = results.boxes.xyxy.cpu().numpy()

    # Annotate a separate canvas so boxes/labels drawn for one grain never
    # leak into the crop of a later, overlapping grain.
    canvas = rgb.copy()
    for box in boxes:
        x1, y1, x2, y2 = map(int, box[:4])
        # Clamp to the image and skip boxes that collapse to an empty crop.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, width), min(y2, height)
        if x2 <= x1 or y2 <= y1:
            continue

        crop = np.ascontiguousarray(rgb[y1:y2, x1:x2])
        predicted_label = classify_crop(Image.fromarray(crop))

        # Draw bounding box and label; keep the text inside the frame when the
        # box touches the top edge.
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
        text_y = max(y1 - 10, 20)
        cv2.putText(canvas, predicted_label, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)

    return Image.fromarray(canvas)
# Gradio Interface
# Only offer example images that actually exist. With cache_examples=True,
# Gradio runs detect_and_classify on every listed example while building
# its cache, so a missing placeholder file would abort start-up.
EXAMPLE_IMAGES = [
    [path] for path in ("example1.jpg", "example2.jpg") if Path(path).is_file()
]

# UI text is in Urdu. Title: "Rice identification system".
with gr.Blocks(title="چاول کا شناختی نظام") as demo:
    # "Upload an image containing rice grains. The system will identify each
    # grain and report its variety."
    gr.Markdown("""
# چاول کا شناختی نظام
ایک تصویر اپ لوڈ کریں جس میں چاول کے دانے ہوں۔ نظام ہر دانے کو پہچان کر اس کی قسم بتائے گا۔
""")
    with gr.Row():
        input_image = gr.Image(type="pil", label="تصویر داخل کریں")  # "Input image"
        output_image = gr.Image(type="pil", label="نتیجہ")  # "Result"
    submit_btn = gr.Button("تشخیص کریں")  # "Identify"
    submit_btn.click(
        fn=detect_and_classify,
        inputs=input_image,
        outputs=output_image,
    )
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=EXAMPLE_IMAGES,  # drop example1.jpg / example2.jpg next to app.py
            inputs=input_image,
            outputs=output_image,
            fn=detect_and_classify,
            cache_examples=True,
        )

# Launch only when run as a script (as Spaces does), not when imported.
if __name__ == "__main__":
    demo.launch()