KRISH09bha's picture
Update app.py
497cb0f verified
raw
history blame
4.76 kB
import base64
import logging
import os
import shutil
import tempfile

import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import FileResponse
from ultralytics import YOLO
# Initialize FastAPI app
app = FastAPI()

logger = logging.getLogger(__name__)

# Load YOLO model safely: if the model file is missing or fails to load,
# `model` stays None and the detection endpoints answer with an error instead
# of the whole app failing at import time.
# The path defaults to "12x.pt" and can be overridden via the MODEL_PATH
# environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "12x.pt")
model = None
if not os.path.exists(MODEL_PATH):
    logger.warning(
        "⚠ Model file '%s' not found. API will not work properly.", MODEL_PATH
    )
else:
    try:
        model = YOLO(MODEL_PATH)
    except Exception:
        # Top-level boundary: log the full traceback but keep the API up.
        logger.exception("Failed to load YOLO model from '%s'.", MODEL_PATH)
def process_frame(frame):
    """Run the YOLO model on one frame and summarise its detections.

    Args:
        frame: image handed straight to ``model(...)``; presumably an OpenCV
            BGR array as produced by the callers — TODO confirm.

    Returns:
        A ``(predictions, object_count)`` tuple. ``predictions`` is a list of
        dicts with keys ``"class"``, ``"confidence"`` and ``"bbox"`` (four
        floats taken from ``box.xyxy[0]``); ``object_count`` maps each class
        name to how many detections carried it. Both are empty when no model
        is loaded.
    """
    if model is None:
        return [], {}

    detections = []
    tally = {}
    # Flatten the per-result box lists into one stream of boxes.
    for box in (b for result in model(frame) for b in result.boxes):
        cls_idx = int(box.cls)
        # Fall back to a synthetic label if the index is absent from model.names.
        label = model.names.get(cls_idx, f"Unknown_{cls_idx}")
        detections.append(
            {
                "class": label,
                "confidence": float(box.conf),
                "bbox": list(map(float, box.xyxy[0])),
            }
        )
        tally[label] = tally.get(label, 0) + 1
    return detections, tally
@app.post("/upload-image/")
async def upload_image(file: UploadFile = File(...)):
    """Upload an image and get object detection results.

    The upload is decoded with OpenCV, run through ``process_frame``, and each
    detection is drawn onto the image as a green box labelled
    ``"<class> (<confidence>)"``.

    Returns:
        dict with ``"image"`` (the annotated image as a base64 JPEG data URI)
        and ``"object_count"`` (class name -> number of detections).

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image;
            500 if the model is not loaded or processing fails.
    """
    # NOTE(review): inference runs synchronously inside this async handler, so
    # it blocks the event loop while it runs. Moving it to a thread pool would
    # need a lock around the shared `model` — confirm YOLO thread-safety first.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")
    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals undecodable data by returning None, not raising.
            raise HTTPException(status_code=400, detail="Could not decode image file.")

        predictions, object_count = process_frame(img)

        # Draw bounding boxes on the image
        for pred in predictions:
            x1, y1, x2, y2 = map(int, pred["bbox"])
            label = f"{pred['class']} ({pred['confidence']:.2f})"
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Clamp so labels of boxes touching the top edge stay on-canvas.
            cv2.putText(img, label, (x1, max(y1 - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        ok, buffer = cv2.imencode(".jpg", img)
        if not ok:
            raise RuntimeError("JPEG encoding failed.")
        img_base64 = base64.b64encode(buffer).decode("utf-8")
        return {
            "image": f"data:image/jpeg;base64,{img_base64}",
            "object_count": object_count,
        }
    except HTTPException:
        # Client errors raised above must keep their status, not become 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
@app.post("/upload-video/")
async def upload_video(file: UploadFile = File(...)):
    """Upload a video, process it frame by frame, and return the processed video.

    The upload is spooled to a temporary ``.mp4`` file, read with OpenCV and
    re-encoded with the ``mp4v`` codec at the source frame size and rate.
    Detection runs on every ``frame_interval``-th frame; the most recent
    detections are redrawn on the frames in between so boxes don't flicker.

    Returns:
        FileResponse serving the result as ``processed_video.mp4``. The output
        file is deleted by a background task after the response is sent.

    Raises:
        HTTPException: 400 if OpenCV cannot open the upload as a video;
            500 if the model is not loaded or processing fails.
    """
    # Local import keeps this block self-contained; FastAPI's FileResponse is
    # Starlette's and accepts a Starlette background task.
    from starlette.background import BackgroundTask

    # NOTE(review): the whole video is processed synchronously inside this
    # async handler, blocking the event loop until it finishes. Offloading to a
    # thread would need a lock around the shared `model` — confirm first.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")

    frame_interval = 5  # Process every 5th frame for efficiency
    input_path = None
    output_path = None
    cap = None
    out = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
            input_path = temp_video.name
            shutil.copyfileobj(file.file, temp_video)

        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise HTTPException(status_code=400, detail="Could not open video file.")

        # Keep the fractional rate (e.g. 29.97): truncating to int makes the
        # output play slightly slower. Some sources report 0 (or NaN); fall
        # back to an assumed 30 fps so the writer can still be created.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_path = os.path.splitext(input_path)[0] + "_processed.mp4"
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError("Could not create output video writer.")

        predictions = []
        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_index % frame_interval == 0:
                predictions, _ = process_frame(frame)
            # Draw the latest detections on every frame (including skipped ones).
            for pred in predictions:
                x1, y1, x2, y2 = map(int, pred["bbox"])
                label = f"{pred['class']} ({pred['confidence']:.2f})"
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Clamp so labels of boxes touching the top edge stay on-canvas.
                cv2.putText(frame, label, (x1, max(y1 - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            out.write(frame)
            frame_index += 1

        # Release before serving so the output file is completely written.
        cap.release()
        out.release()

        response = FileResponse(
            output_path,
            media_type="video/mp4",
            filename="processed_video.mp4",
            background=BackgroundTask(os.remove, output_path),
        )
        output_path = None  # ownership passed to the background task
        return response
    except HTTPException:
        # Client errors raised above must keep their status, not become 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}") from e
    finally:
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        # Best-effort cleanup: always the input, the output only on failure.
        for path in (input_path, output_path):
            if path is not None:
                try:
                    os.remove(path)
                except OSError:
                    pass
@app.get("/")
def home():
    """Root endpoint: return a fixed banner message describing this API."""
    payload = {"message": "🎯 Object Detection API for Images and Videos using 12x.pt"}
    return payload