KRISH09bha commited on
Commit
497cb0f
·
verified ·
1 Parent(s): 3bc40d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -60
app.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi import FastAPI, File, UploadFile
 
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
@@ -11,12 +12,12 @@ import tempfile
11
  app = FastAPI()
12
 
13
  # Load YOLO model safely
14
- model_path = "12x.pt"
15
- if not os.path.exists(model_path):
16
- print(f"Warning: Model file '{model_path}' not found. API will not work properly.")
17
- model = None # Handle model loading failure
18
  else:
19
- model = YOLO(model_path)
20
 
21
 
22
  def process_frame(frame):
@@ -49,82 +50,91 @@ def process_frame(frame):
49
  async def upload_image(file: UploadFile = File(...)):
50
  """Upload an image and get object detection results."""
51
  if model is None:
52
- return {"error": "Model not loaded. Please upload '12x.pt' to run detection."}
53
 
54
- contents = await file.read()
55
- nparr = np.frombuffer(contents, np.uint8)
56
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
 
57
 
58
- predictions, object_count = process_frame(img)
59
 
60
- # Draw bounding boxes on the image
61
- for pred in predictions:
62
- x1, y1, x2, y2 = map(int, pred["bbox"])
63
- label = f"{pred['class']} ({pred['confidence']:.2f})"
64
- cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
65
- cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
66
 
67
- _, buffer = cv2.imencode('.jpg', img)
68
- img_base64 = base64.b64encode(buffer).decode('utf-8')
69
 
70
- return {
71
- "image": f"data:image/jpeg;base64,{img_base64}",
72
- "object_count": object_count
73
- }
 
 
 
74
 
75
 
76
  @app.post("/upload-video/")
77
  async def upload_video(file: UploadFile = File(...)):
78
  """Upload a video, process it frame by frame, and return the processed video."""
79
  if model is None:
80
- return {"error": "Model not loaded. Please upload '12x.pt' to run detection."}
 
 
 
 
 
 
 
 
 
 
81
 
82
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
83
- shutil.copyfileobj(file.file, temp_video)
84
- temp_video_path = temp_video.name
 
 
85
 
86
- cap = cv2.VideoCapture(temp_video_path)
87
- if not cap.isOpened():
88
- os.remove(temp_video_path)
89
- return {"error": "Could not open video file"}
90
 
91
- # Get video properties
92
- fps = int(cap.get(cv2.CAP_PROP_FPS))
93
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
94
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
95
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
96
 
97
- # Output video
98
- output_video_path = temp_video_path.replace(".mp4", "_processed.mp4")
99
- out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
 
100
 
101
- frame_interval = 5 # Process every 5th frame for efficiency
102
- frame_index = 0
103
 
104
- while True:
105
- ret, frame = cap.read()
106
- if not ret:
107
- break
 
 
108
 
109
- if frame_index % frame_interval == 0:
110
- predictions, _ = process_frame(frame)
111
 
112
- # Draw bounding boxes on the frame
113
- for pred in predictions:
114
- x1, y1, x2, y2 = map(int, pred["bbox"])
115
- label = f"{pred['class']} ({pred['confidence']:.2f})"
116
- cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
117
- cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
118
 
119
- out.write(frame)
120
- frame_index += 1
121
 
122
- cap.release()
123
- out.release()
124
- os.remove(temp_video_path) # Clean up temp file
125
 
126
- return FileResponse(output_video_path, media_type="video/mp4", filename="processed_video.mp4")
127
 
128
  @app.get("/")
129
  def home():
130
- return {"message": "Object Detection API for Images and Videos using 12x.pt"}
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import FileResponse
3
  import cv2
4
  import numpy as np
5
  from ultralytics import YOLO
 
12
  app = FastAPI()
13
 
14
  # Load YOLO model safely
15
+ MODEL_PATH = "12x.pt"
16
+ if not os.path.exists(MODEL_PATH):
17
+ print(f"Warning: Model file '{MODEL_PATH}' not found. API will not work properly.")
18
+ model = None
19
  else:
20
+ model = YOLO(MODEL_PATH)
21
 
22
 
23
  def process_frame(frame):
 
50
  async def upload_image(file: UploadFile = File(...)):
51
  """Upload an image and get object detection results."""
52
  if model is None:
53
+ raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")
54
 
55
+ try:
56
+ contents = await file.read()
57
+ nparr = np.frombuffer(contents, np.uint8)
58
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
59
 
60
+ predictions, object_count = process_frame(img)
61
 
62
+ # Draw bounding boxes on the image
63
+ for pred in predictions:
64
+ x1, y1, x2, y2 = map(int, pred["bbox"])
65
+ label = f"{pred['class']} ({pred['confidence']:.2f})"
66
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
67
+ cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
68
 
69
+ _, buffer = cv2.imencode('.jpg', img)
70
+ img_base64 = base64.b64encode(buffer).decode('utf-8')
71
 
72
+ return {
73
+ "image": f"data:image/jpeg;base64,{img_base64}",
74
+ "object_count": object_count
75
+ }
76
+
77
+ except Exception as e:
78
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
79
 
80
 
81
  @app.post("/upload-video/")
82
  async def upload_video(file: UploadFile = File(...)):
83
  """Upload a video, process it frame by frame, and return the processed video."""
84
  if model is None:
85
+ raise HTTPException(status_code=500, detail="Model not loaded. Please upload '12x.pt' to run detection.")
86
+
87
+ try:
88
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
89
+ shutil.copyfileobj(file.file, temp_video)
90
+ temp_video_path = temp_video.name
91
+
92
+ cap = cv2.VideoCapture(temp_video_path)
93
+ if not cap.isOpened():
94
+ os.remove(temp_video_path)
95
+ raise HTTPException(status_code=400, detail="Could not open video file.")
96
 
97
+ # Get video properties
98
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
99
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
100
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
101
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
102
 
103
+ # Output video
104
+ output_video_path = temp_video_path.replace(".mp4", "_processed.mp4")
105
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
 
106
 
107
+ frame_interval = 5 # Process every 5th frame for efficiency
108
+ frame_index = 0
 
 
 
109
 
110
+ while True:
111
+ ret, frame = cap.read()
112
+ if not ret:
113
+ break
114
 
115
+ if frame_index % frame_interval == 0:
116
+ predictions, _ = process_frame(frame)
117
 
118
+ # Draw bounding boxes on the frame
119
+ for pred in predictions:
120
+ x1, y1, x2, y2 = map(int, pred["bbox"])
121
+ label = f"{pred['class']} ({pred['confidence']:.2f})"
122
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
123
+ cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
124
 
125
+ out.write(frame)
126
+ frame_index += 1
127
 
128
+ cap.release()
129
+ out.release()
130
+ os.remove(temp_video_path) # Clean up temp file
 
 
 
131
 
132
+ return FileResponse(output_video_path, media_type="video/mp4", filename="processed_video.mp4")
 
133
 
134
+ except Exception as e:
135
+ raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}")
 
136
 
 
137
 
138
  @app.get("/")
139
  def home():
140
+ return {"message": "🎯 Object Detection API for Images and Videos using 12x.pt"}