new
- app.py +100 -5
- requirements.txt +3 -1
app.py
CHANGED
@@ -9,6 +9,7 @@ from utils.ear_utils import BlinkDetector
 from gradio_webrtc import WebRTC
 from ultralytics import YOLO
 import torch
+import spaces  # Add spaces import
 
 def smooth_values(history, current_value, window_size=5):
     if current_value is not None:
@@ -38,12 +39,12 @@ GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
 DISTRACTION_MODEL_PATH = "best.pt"
 
 # --- Global Initializations ---
-gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
+# Load models on CPU initially
+gaze_predictor = GazePredictor(GAZE_MODEL_PATH, device='cpu')  # Assuming GazePredictor accepts device arg
 blink_detector = BlinkDetector()
 
-# Load Distraction Model
+# Load Distraction Model on CPU initially
 distraction_model = YOLO(DISTRACTION_MODEL_PATH)
-distraction_model.to('cpu')
 
 # Distraction Class Names
 distraction_class_names = [
@@ -64,6 +65,9 @@ is_unconscious = False
 frame_count_webcam = 0
 stop_gaze_processing = False
 
+# --- Global State Variables for Distraction Webcam ---
+stop_distraction_processing = False
+
 # Constants
 GAZE_STABILITY_THRESHOLD = 0.5
 TIME_THRESHOLD = 15
@@ -285,10 +289,22 @@ def terminate_gaze_stream():
     frame_count_webcam = 0
     return "Gaze Processing Terminated. State Reset."
 
+def terminate_distraction_stream():
+    global stop_distraction_processing
+    print("Distraction Live Termination signal received. Stopping processing.")
+    stop_distraction_processing = True
+    return "Distraction Live Processing Terminated."
+
+@spaces.GPU  # Add ZeroGPU decorator
 def process_gaze_frame(frame):
     global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
     global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
 
+    try:
+        gaze_predictor.model.to('cuda')
+    except Exception as e:
+        print(f"Warning: Could not move gaze model to CUDA: {e}")
+
     if stop_gaze_processing:
         return np.zeros((480, 640, 3), dtype=np.uint8)
 
@@ -397,6 +413,66 @@ def process_gaze_frame(frame):
         cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
         return error_frame
 
+@spaces.GPU  # Add ZeroGPU decorator
+def process_distraction_frame(frame):
+    global stop_distraction_processing
+
+    distraction_model.to('cuda')
+
+    if stop_distraction_processing:
+        return np.zeros((480, 640, 3), dtype=np.uint8)
+
+    if frame is None:
+        return np.zeros((480, 640, 3), dtype=np.uint8)
+
+    try:
+        frame_to_process = frame
+
+        results = distraction_model(frame_to_process, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
+
+        display_text = "safe driving"
+        alarm_action = None
+
+        for result in results:
+            if result.boxes is not None and len(result.boxes) > 0:
+                boxes = result.boxes.xyxy.cpu().numpy()
+                scores = result.boxes.conf.cpu().numpy()
+                classes = result.boxes.cls.cpu().numpy()
+
+                if len(boxes) > 0:
+                    max_score_idx = scores.argmax()
+                    detected_action_idx = int(classes[max_score_idx])
+                    if 0 <= detected_action_idx < len(distraction_class_names):
+                        detected_action = distraction_class_names[detected_action_idx]
+                        confidence = scores[max_score_idx]
+                        display_text = f"{detected_action}: {confidence:.2f}"
+                        if detected_action != 'safe driving':
+                            alarm_action = detected_action
+                    else:
+                        print(f"Warning: Detected class index {detected_action_idx} out of bounds.")
+                        display_text = "Unknown Detection"
+
+        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        if alarm_action:
+            print(f"ALARM: Unsafe behavior detected - {alarm_action}!")
+            cv2.putText(frame_bgr, f"ALARM: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+
+        text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
+        cv2.putText(frame_bgr, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
+
+        frame_rgb_processed = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+        return frame_rgb_processed
+
+    except Exception as e:
+        print(f"Error processing distraction frame: {e}")
+        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
+        if not error_frame.flags.writeable:
+            error_frame = error_frame.copy()
+        error_frame_bgr = cv2.cvtColor(error_frame, cv2.COLOR_RGB2BGR)
+        cv2.putText(error_frame_bgr, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
+        error_frame_rgb = cv2.cvtColor(error_frame_bgr, cv2.COLOR_BGR2RGB)
+        return error_frame_rgb
+
 def create_gaze_interface():
     with gr.Blocks() as gaze_demo:
         gr.Markdown("## Real-time Gaze & Drowsiness Tracking")
@@ -425,6 +501,24 @@ def create_distraction_interface():
     )
     return distraction_demo
 
+def create_distraction_live_interface():
+    with gr.Blocks() as distraction_live_demo:
+        gr.Markdown("## Real-time Distraction Detection (Live)")
+        with gr.Row():
+            webcam_stream = WebRTC(label="Webcam Stream")
+        with gr.Row():
+            terminate_btn = gr.Button("Terminate Process")
+
+        webcam_stream.stream(
+            fn=process_distraction_frame,
+            inputs=[webcam_stream],
+            outputs=[webcam_stream]
+        )
+
+        terminate_btn.click(fn=terminate_distraction_stream, inputs=None, outputs=None)
+
+    return distraction_live_demo
+
 def create_video_interface():
     video_demo = gr.Interface(
         fn=analyze_video,
@@ -436,8 +530,8 @@ def create_video_interface():
     return video_demo
 
 demo = gr.TabbedInterface(
-    [create_video_interface(), create_gaze_interface(), create_distraction_interface()],
-    ["Gaze Video Upload", "Gaze & Drowsiness (Live)", "Distraction Video Upload"],
+    [create_video_interface(), create_gaze_interface(), create_distraction_interface(), create_distraction_live_interface()],
+    ["Gaze Video Upload", "Gaze & Drowsiness (Live)", "Distraction Video Upload", "Distraction Detection (Live)"],
     title="Driver Monitoring System"
 )
 
@@ -453,4 +547,5 @@ if __name__ == "__main__":
    is_unconscious = False
    frame_count_webcam = 0
    stop_gaze_processing = False
+   stop_distraction_processing = False
    demo.launch()
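Note on the ZeroGPU pattern this commit adopts: on Hugging Face ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated function is running, so models are loaded on CPU at import time and moved to CUDA inside the decorated call. A minimal standalone sketch of that flow (the predict name and the CPU fallback are illustrative, not part of this repo):

    import spaces            # Hugging Face ZeroGPU helper
    import torch
    from ultralytics import YOLO

    # At import time on ZeroGPU no GPU is attached yet, so load on CPU.
    model = YOLO("best.pt")  # same checkpoint name the diff above uses

    @spaces.GPU              # a GPU is allocated only for the duration of this call
    def predict(frame):
        # Request CUDA inside the decorated call; fall back to CPU so the
        # sketch also runs on a machine without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model(frame, verbose=False)

This is also why the diff wraps gaze_predictor.model.to('cuda') in a try/except: the move is only expected to succeed once the decorated function is executing on an allocated GPU.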
requirements.txt
CHANGED
@@ -11,4 +11,6 @@ tensorflow
 pygame
 twilio
 ultralytics==8.3.93
-torch==2.6.0
+# torch==2.6.0 # Replace with ZeroGPU compatible version, e.g., 2.4.0
+torch==2.4.0 # Example compatible version
+spaces # Add spaces for ZeroGPU
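A quick post-install smoke test for these pins (a sketch; 2.4.0 is just the example version the commit itself pins, and the authoritative list of ZeroGPU-compatible torch versions is in the Hugging Face docs):

    # Hypothetical check: the pinned packages import together and report versions.
    import torch
    import ultralytics

    print("torch:", torch.__version__)              # expect 2.4.0 per the pin above
    print("ultralytics:", ultralytics.__version__)  # expect 8.3.93
    print("CUDA at import:", torch.cuda.is_available())  # False on ZeroGPU outside @spaces.GPU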