# yolo_hippo / app.py
# Author: hippoiam10 — commit ca83661 ("Update app.py", verified)
import os
import urllib.request

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw
from sklearn.metrics import precision_recall_fscore_support
from torchvision import transforms
from ultralytics import YOLO
model_url = "https://huggingface.co/Ultralytics/YOLOv8/resolve/main/yolov8n.pt"
model_path = "yolov8n.pt"
# Download the YOLOv8n weights only when they are not already on disk, so a
# restart neither re-fetches the file nor fails when the network is down.
if not os.path.exists(model_path):
    # Download under a temporary name and rename afterwards: an interrupted
    # download must not leave a truncated file that the existence check above
    # would accept on the next start.
    _partial_path = model_path + ".part"
    urllib.request.urlretrieve(model_url, _partial_path)
    os.replace(_partial_path, model_path)
yolo_model = YOLO(model_path)
# Load torchvision's Faster R-CNN (ResNet-50 FPN) with the COCO_V1 pretrained
# weights and put it in evaluation (inference) mode.
faster_rcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="COCO_V1")
faster_rcnn_model.eval()
# Score threshold used both to draw Faster R-CNN boxes and to count a
# detection as "positive" in the summary metrics.
_SCORE_THRESHOLD = 0.5


def _evaluate_model(pred_boxes, confs, threshold=_SCORE_THRESHOLD):
    """Return ``(precision, recall, f1)`` for detections scoring above *threshold*.

    NOTE(review): no ground-truth annotations are available here, so every
    detection above the threshold is labelled 1 in BOTH ``y_true`` and
    ``y_pred``. The result is therefore degenerate: 1.0 for all three values
    when at least one detection passes the threshold, 0.0 otherwise. These
    are not real accuracy metrics; proper evaluation needs labelled boxes and
    IoU matching.
    """
    positives = sum(1 for _box, conf in zip(pred_boxes, confs) if conf > threshold)
    if positives == 0:
        # Nothing passed the threshold: report the zeros that zero_division=0
        # stands for instead of handing sklearn empty label arrays.
        return 0.0, 0.0, 0.0
    labels = [1] * positives
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, labels, average="binary", zero_division=0
    )
    return precision, recall, f1


def _mean_confidence(scores):
    """Return the mean of *scores* rounded to 3 decimals, or None if there are none.

    ``np.mean`` of an empty array is NaN (with a RuntimeWarning), and NaN is
    not valid JSON, so "no detections" is reported as ``None`` instead.
    """
    if len(scores) == 0:
        return None
    return round(float(np.mean(scores)), 3)


def _summarize(boxes, scores, threshold=_SCORE_THRESHOLD):
    """Build the per-model metrics dict shown in the Gradio JSON panel."""
    precision, recall, f1 = _evaluate_model(boxes, scores, threshold)
    return {
        "Precision": round(float(precision), 3),
        "Recall": round(float(recall), 3),
        "F1-Score": round(float(f1), 3),
        "Confidence": _mean_confidence(scores),
    }


def _draw_boxes(image, boxes, scores, threshold=_SCORE_THRESHOLD):
    """Return a copy of *image* with every box scoring above *threshold* outlined in red."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for box, score in zip(boxes, scores):
        if score > threshold:
            x0, y0, x1, y1 = (float(v) for v in box[:4])
            draw.rectangle(((x0, y0), (x1, y1)), outline="red", width=3)
    return annotated


def detect_objects(image):
    """Run YOLOv8 and Faster R-CNN on *image* and return both results side by side.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input; None when
            the user submits without uploading anything.

    Returns:
        A 3-tuple ``(yolo_image, rcnn_image, evaluation_results)``:
        the YOLO-annotated image (PIL, RGB); a copy of the input with Faster
        R-CNN boxes scoring above 0.5 drawn in red; and a dict keyed by
        ``"YOLO"`` / ``"Faster R-CNN"`` holding rounded ``"Precision"``,
        ``"Recall"``, ``"F1-Score"`` (degenerate without ground truth, see
        ``_evaluate_model``) and mean ``"Confidence"`` (None if no boxes).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard ToTensor(None) fails with an opaque TypeError.
        raise gr.Error("請先上傳圖片。")
    # Normalise to 3-channel RGB so both detectors receive the same RGB input;
    # an RGBA upload would otherwise become a 4-channel tensor for Faster R-CNN.
    image = image.convert("RGB")

    # YOLO detection
    yolo_result = yolo_model(image)[0]
    # Results.plot() returns a BGR array (OpenCV convention); reverse the
    # channel axis so PIL, which assumes RGB, shows the correct colours.
    yolo_image = Image.fromarray(np.ascontiguousarray(yolo_result.plot()[..., ::-1]))
    yolo_boxes = yolo_result.boxes.xyxy.cpu().numpy()
    yolo_confidence = yolo_result.boxes.conf.cpu().numpy()

    # Faster R-CNN detection: the torchvision model takes a list of [C, H, W]
    # tensors and returns one prediction dict per image.
    img_tensor = transforms.ToTensor()(image)
    with torch.no_grad():
        prediction = faster_rcnn_model([img_tensor])[0]
    rcnn_boxes = prediction["boxes"].cpu().numpy()
    rcnn_scores = prediction["scores"].cpu().numpy()
    rcnn_image = _draw_boxes(image, rcnn_boxes, rcnn_scores, _SCORE_THRESHOLD)

    evaluation_results = {
        "YOLO": _summarize(yolo_boxes, yolo_confidence, _SCORE_THRESHOLD),
        "Faster R-CNN": _summarize(rcnn_boxes, rcnn_scores, _SCORE_THRESHOLD),
    }
    return yolo_image, rcnn_image, evaluation_results
# Gradio UI: one uploaded image in; the YOLO result, the Faster R-CNN result
# and the metrics dict out, in the order detect_objects returns them.
demo = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil", label="上傳圖片"),
    outputs=[
        gr.Image(type="pil", label="YOLO 偵測結果"),
        gr.Image(type="pil", label="Faster R-CNN 偵測結果"),
        gr.JSON(label="評估指標"),
    ],
    title="YOLO vs Faster R-CNN 物件偵測",
    description="上傳圖片,系統將使用 YOLOv8 和 Faster R-CNN 進行偵測並顯示結果。",
)

if __name__ == "__main__":
    # Enable Gradio's request queue so incoming requests are queued rather
    # than overloading the server.
    demo.queue()
    # Bind to all interfaces on port 7860 so the app is reachable from other
    # machines, not just localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860)