"""FastAPI service that runs a YOLO model on uploaded images."""

import base64
import os
import shutil
import tempfile

import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from ultralytics import YOLO

# Initialize FastAPI app
app = FastAPI()

# Load YOLO model safely: if the weights file is missing, `model` stays None so
# the app can still start; endpoints check for this and report an error.
MODEL_PATH = "12x.pt"
if not os.path.exists(MODEL_PATH):
    print(f"⚠ Warning: Model file '{MODEL_PATH}' not found. API will not work properly.")
    model = None
else:
    model = YOLO(MODEL_PATH)


def process_frame(frame):
    """Run the YOLO model on a single image and collect its detections.

    Args:
        frame: Image array as produced by ``cv2.imdecode`` (may be None).

    Returns:
        A ``(predictions, object_count)`` tuple where ``predictions`` is a list
        of ``{"class", "confidence", "bbox"}`` dicts (bbox as
        ``[x1, y1, x2, y2]`` floats) and ``object_count`` maps each class
        name to the number of detections of that class. Both are empty when
        the model is not loaded or ``frame`` is None.
    """
    # A None frame must never reach the model: Ultralytics' predict() falls
    # back to its bundled sample images when given source=None.
    if model is None or frame is None:
        return [], {}

    results = model(frame)
    predictions = []
    object_count = {}

    for result in results:
        for box in result.boxes:
            class_id = int(box.cls)
            # Fall back to a synthetic name if the id is missing from the model's label map.
            class_name = model.names.get(class_id, f"Unknown_{class_id}")
            predictions.append({
                "class": class_name,
                "confidence": float(box.conf),
                "bbox": [float(x) for x in box.xyxy[0]],
            })
            object_count[class_name] = object_count.get(class_name, 0) + 1

    return predictions, object_count


def _decode_image(contents):
    """Decode raw upload bytes into an OpenCV image or raise a 400 HTTPException."""
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    # cv2.imdecode signals undecodable data by returning None rather than raising.
    if img is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.")
    return img


def _draw_detections(img, predictions):
    """Draw each prediction's box and "class (confidence)" label onto img in place."""
    for pred in predictions:
        x1, y1, x2, y2 = map(int, pred["bbox"])
        label = f"{pred['class']} ({pred['confidence']:.2f})"
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Put the label just above the box; if the box touches the top edge
        # there is no room, so draw it inside the box instead of off-image.
        text_y = y1 - 10 if y1 >= 20 else y1 + 15
        cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


@app.post("/upload-image/")
async def upload_image(file: UploadFile = File(...)):
    """Upload an image and get object detection results.

    Returns:
        ``{"image": <JPEG data URL with boxes drawn>, "object_count": {class: n}}``.

    Raises:
        HTTPException: 500 if the model is not loaded; 400 if the upload is
            empty or cannot be decoded as an image; 500 for any other
            processing failure.
    """
    if model is None:
        raise HTTPException(
            status_code=500,
            detail=f"Model not loaded. Please upload '{MODEL_PATH}' to run detection.",
        )

    try:
        contents = await file.read()
        img = _decode_image(contents)

        # NOTE: inference runs synchronously inside this async handler, so it
        # blocks the event loop while it runs. That also serializes access to
        # the shared `model`; offloading to a thread pool would need per-thread
        # models or a lock.
        predictions, object_count = process_frame(img)
        _draw_detections(img, predictions)

        ok, buffer = cv2.imencode(".jpg", img)
        if not ok:
            raise RuntimeError("JPEG encoding failed")
        img_base64 = base64.b64encode(buffer).decode("utf-8")
    except HTTPException:
        # Already carries the intended status code (e.g. 400 for bad input).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "image": f"data:image/jpeg;base64,{img_base64}",
        "object_count": object_count,
    }


@app.get("/")
def home():
    """Health/info endpoint describing the service."""
    return {"message": "🎯 Object Detection API for Images and Videos using 12x.pt"}