KRISH09bha committed on
Commit a67480e · verified · 1 Parent(s): 533a626

Update app.py

Files changed (1)
  1. app.py +31 -44
app.py CHANGED
@@ -1,5 +1,4 @@
 from fastapi import FastAPI, File, UploadFile
-from fastapi.responses import StreamingResponse
 import cv2
 import numpy as np
 from ultralytics import YOLO
@@ -7,27 +6,33 @@ import base64
 import os
 import shutil
 import tempfile
-import asyncio

 # Initialize FastAPI app
 app = FastAPI()

-# Load YOLO model (Ensure 12x.pt exists)
+# Load YOLO model safely
 model_path = "12x.pt"
 if not os.path.exists(model_path):
-    raise FileNotFoundError(f"Model file '{model_path}' not found. Please place it in the project directory.")
+    print(f"Warning: Model file '{model_path}' not found. API will not work properly.")
+    model = None  # Handle model loading failure
+else:
+    model = YOLO(model_path)

-model = YOLO(model_path)

 def process_frame(frame):
     """Process a single frame with YOLO and return predictions."""
+    if model is None:
+        return [], {}
+
     results = model(frame)
     predictions = []
     object_count = {}

     for result in results:
         for box in result.boxes:
-            class_name = result.names[int(box.cls)]
+            class_id = int(box.cls)
+            class_name = model.names.get(class_id, f"Unknown_{class_id}")  # Handle missing class names
+
             predictions.append({
                 "class": class_name,
                 "confidence": float(box.conf),
@@ -39,48 +44,20 @@ def process_frame(frame):

     return predictions, object_count

-def encode_frame(frame):
-    """Encode a frame as JPEG and return base64-encoded string."""
-    _, buffer = cv2.imencode('.jpg', frame)
-    return base64.b64encode(buffer).decode('utf-8')
-
-@app.get("/video-stream/")
-async def video_stream():
-    """Endpoint to stream video frames with real-time object detection."""
-    cap = cv2.VideoCapture(0)
-    if not cap.isOpened():
-        return {"error": "Could not open webcam"}
-
-    async def generate():
-        while True:
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            predictions, _ = process_frame(frame)
-
-            # Draw bounding boxes
-            for pred in predictions:
-                x1, y1, x2, y2 = map(int, pred["bbox"])
-                label = f"{pred['class']} ({pred['confidence']:.2f})"
-                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
-                cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
-
-            _, buffer = cv2.imencode('.jpg', frame)
-            yield (b'--frame\r\n' b'Content-Type: image/jpeg\r\n\r\n' + buffer.tobytes() + b'\r\n')
-            await asyncio.sleep(0.1)  # Adjust frame rate
-
-    return StreamingResponse(generate(), media_type="multipart/x-mixed-replace; boundary=frame")

 @app.post("/upload-image/")
 async def upload_image(file: UploadFile = File(...)):
-    """Endpoint to upload an image and get object detection results."""
+    """Upload an image and get object detection results."""
+    if model is None:
+        return {"error": "Model not loaded. Please upload '12x.pt' to run detection."}
+
     contents = await file.read()
     nparr = np.frombuffer(contents, np.uint8)
     img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

     predictions, object_count = process_frame(img)

+    # Draw bounding boxes on the image
     for pred in predictions:
         x1, y1, x2, y2 = map(int, pred["bbox"])
         label = f"{pred['class']} ({pred['confidence']:.2f})"
@@ -90,17 +67,25 @@ async def upload_image(file: UploadFile = File(...)):
     _, buffer = cv2.imencode('.jpg', img)
     img_base64 = base64.b64encode(buffer).decode('utf-8')

-    return {"image": img_base64, "object_count": object_count}
+    return {
+        "image": f"data:image/jpeg;base64,{img_base64}",
+        "object_count": object_count
+    }
+

 @app.post("/upload-video/")
 async def upload_video(file: UploadFile = File(...)):
-    """Endpoint to upload a video, process it frame by frame, and return detection results."""
+    """Upload a video, process it frame by frame, and return detection results."""
+    if model is None:
+        return {"error": "Model not loaded. Please upload '12x.pt' to run detection."}
+
     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
         shutil.copyfileobj(file.file, temp_video)
         temp_video_path = temp_video.name

     cap = cv2.VideoCapture(temp_video_path)
     if not cap.isOpened():
+        os.remove(temp_video_path)
         return {"error": "Could not open video file"}

     frame_results = []
@@ -116,16 +101,18 @@ async def upload_video(file: UploadFile = File(...)):
         predictions, object_count = process_frame(frame)
         frame_results.append({
             "frame_index": frame_index,
-            "object_count": object_count
+            "object_count": object_count,
+            "detections": predictions
         })

         frame_index += 1

     cap.release()
-    os.remove(temp_video_path)
+    os.remove(temp_video_path)  # Clean up temporary file

     return {"video_results": frame_results}

+
 @app.get("/")
 def home():
-    return {"message": "Real-Time Object Detection API with Image, Video, and Streaming Support"}
+    return {"message": "Object Detection API for Images and Videos using 12x.pt"}