# Commit 8a1d8af ("Update app.py") by KRISH09bha, marked verified.
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import FileResponse
import cv2
import numpy as np
from ultralytics import YOLO
import base64
import os
import shutil
import tempfile
# FastAPI application; the route decorators below register endpoints on it.
app = FastAPI()

# YOLO weights file. When it is absent the service still starts, but `model`
# stays None and the detection endpoint refuses requests.
MODEL_PATH = "12x.pt"
model = None
if os.path.exists(MODEL_PATH):
    model = YOLO(MODEL_PATH)
else:
    print(f"⚠ Warning: Model file '{MODEL_PATH}' not found. API will not work properly.")
def process_frame(frame):
    """Run the YOLO model on one image and summarise its detections.

    Args:
        frame: image handed straight to ``model(...)``; presumably a BGR numpy
            array as produced by ``cv2.imdecode`` in the upload handler — TODO confirm.

    Returns:
        A ``(predictions, object_count)`` tuple. ``predictions`` is a list of
        dicts with keys ``"class"`` (class name), ``"confidence"`` (float) and
        ``"bbox"`` (the four floats of ``box.xyxy[0]``). ``object_count`` maps
        each class name to how many detections carried it. Both are empty when
        no model is loaded.
    """
    if model is None:
        return [], {}

    detections = []
    counts = {}
    for result in model(frame):
        for box in result.boxes:
            cls_idx = int(box.cls)
            # Fall back to a synthetic label when the index is absent from the model's names map.
            name = model.names.get(cls_idx, f"Unknown_{cls_idx}")
            detections.append(
                {
                    "class": name,
                    "confidence": float(box.conf),
                    "bbox": list(map(float, box.xyxy[0])),
                }
            )
            counts[name] = counts.get(name, 0) + 1
    return detections, counts
@app.post("/upload-image/")
async def upload_image(file: UploadFile = File(...)):
    """Upload an image and get object detection results.

    Decodes the uploaded bytes with OpenCV and runs ``process_frame`` on them.
    Each detection is drawn as a green box with a ``"class (confidence)"``
    label, and the annotated image is returned as a base64 JPEG data URL.

    Returns:
        ``{"image": "data:image/jpeg;base64,...", "object_count": {class_name: count}}``

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if the model is not loaded or processing fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")

    # NOTE(review): inference and drawing below are synchronous and CPU-bound, so they
    # block the event loop while a request is processed. Offloading to a worker thread
    # needs care, because sharing one YOLO instance across threads is not thread-safe.
    try:
        contents = await file.read()
        # Reject an empty upload up front. Otherwise imdecode returns None (it does not
        # raise) when the bytes are not a decodable image. Both cases are client errors.
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
        if img is None:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or not a valid image.")

        predictions, object_count = process_frame(img)

        # Draw bounding boxes and labels onto the decoded image in place.
        for pred in predictions:
            x1, y1, x2, y2 = map(int, pred["bbox"])
            label = f"{pred['class']} ({pred['confidence']:.2f})"
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Clamp the text baseline so labels of boxes touching the top edge stay visible.
            cv2.putText(img, label, (x1, max(y1 - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        ok, buffer = cv2.imencode(".jpg", img)
        if not ok:
            raise RuntimeError("JPEG encoding of the annotated image failed")
        img_base64 = base64.b64encode(buffer).decode("utf-8")
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "image": f"data:image/jpeg;base64,{img_base64}",
        "object_count": object_count,
    }
@app.get("/")
def home():
    """Root endpoint: return a short banner identifying the service."""
    banner = "🎯 Object Detection API for Images and Videos using 12x.pt"
    return {"message": banner}