from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask
import cv2
import numpy as np
from ultralytics import YOLO
import base64
import os
import shutil
import tempfile

# Initialize FastAPI app
app = FastAPI()

# Load YOLO model safely
MODEL_PATH = "12x.pt"
if not os.path.exists(MODEL_PATH):
    print(f"⚠ Warning: Model file '{MODEL_PATH}' not found. API will not work properly.")
    model = None
else:
    model = YOLO(MODEL_PATH)


def process_frame(frame):
    """Process a single frame with YOLO and return predictions."""
    if model is None:
        return [], {}

    results = model(frame)
    predictions = []
    object_count = {}

    for result in results:
        for box in result.boxes:
            class_id = int(box.cls)
            class_name = model.names.get(class_id, f"Unknown_{class_id}")  # Handle missing class names

            predictions.append({
                "class": class_name,
                "confidence": float(box.conf),
                "bbox": [float(x) for x in box.xyxy[0]]
            })

            # Count objects
            object_count[class_name] = object_count.get(class_name, 0) + 1

    return predictions, object_count


@app.post("/upload-image/")
async def upload_image(file: UploadFile = File(...)):
    """Upload an image and get object detection results."""
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")

    try:
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            raise HTTPException(status_code=400, detail="Could not decode image file.")

        predictions, object_count = process_frame(img)

        # Draw bounding boxes on the image
        for pred in predictions:
            x1, y1, x2, y2 = map(int, pred["bbox"])
            label = f"{pred['class']} ({pred['confidence']:.2f})"
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        _, buffer = cv2.imencode('.jpg', img)
        img_base64 = base64.b64encode(buffer).decode('utf-8')

        return {
            "image": f"data:image/jpeg;base64,{img_base64}",
            "object_count": object_count
        }
    except HTTPException:
        raise  # Preserve intentional 4xx responses instead of re-wrapping them as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.post("/upload-video/")
async def upload_video(file: UploadFile = File(...)):
    """Upload a video, process it frame by frame, and return the processed video."""
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")
    try:
        # Persist the upload to a temporary file so OpenCV can read it by path
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
            shutil.copyfileobj(file.file, temp_video)
            temp_video_path = temp_video.name

        cap = cv2.VideoCapture(temp_video_path)
        if not cap.isOpened():
            os.remove(temp_video_path)
            raise HTTPException(status_code=400, detail="Could not open video file.")

        # Get video properties (fall back to 30 FPS if the container does not report a frame rate)
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        # Output video
        output_video_path = temp_video_path.replace(".mp4", "_processed.mp4")
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        frame_interval = 5  # Process every 5th frame for efficiency
        frame_index = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_index % frame_interval == 0:
                predictions, _ = process_frame(frame)

                # Draw bounding boxes on the frame
                for pred in predictions:
                    x1, y1, x2, y2 = map(int, pred["bbox"])
                    label = f"{pred['class']} ({pred['confidence']:.2f})"
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(frame)
            frame_index += 1

        cap.release()
        out.release()
        os.remove(temp_video_path)  # Clean up temp file

        # Delete the processed file after the response has been sent so temp files do not accumulate
        return FileResponse(
            output_video_path,
            media_type="video/mp4",
            filename="processed_video.mp4",
            background=BackgroundTask(os.remove, output_video_path),
        )
    except HTTPException:
        raise  # Preserve intentional 4xx responses instead of re-wrapping them as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}")


@app.get("/")
def home():
    return {"message": "🎯 Object Detection API for Images and Videos using 12x.pt"}
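
# --- Usage note (illustrative sketch, not part of the API) ---
# The commands below assume this file is saved as main.py and that '12x.pt'
# is present in the working directory; adjust names to your setup.
#
# Start the server:
#   uvicorn main:app --reload
#
# Exercise the endpoints, e.g. with curl:
#   curl -X POST -F "file=@sample.jpg" http://127.0.0.1:8000/upload-image/
#   curl -X POST -F "file=@sample.mp4" http://127.0.0.1:8000/upload-video/ -o processed_video.mp4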