Guru-25 committed on
Commit
14b9b6f
·
verified ·
1 Parent(s): cb29a61
Files changed (2) hide show
  1. app.py +100 -5
  2. requirements.txt +3 -1
app.py CHANGED
@@ -9,6 +9,7 @@ from utils.ear_utils import BlinkDetector
9
  from gradio_webrtc import WebRTC
10
  from ultralytics import YOLO
11
  import torch
 
12
 
13
  def smooth_values(history, current_value, window_size=5):
14
  if current_value is not None:
@@ -38,12 +39,12 @@ GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
38
  DISTRACTION_MODEL_PATH = "best.pt"
39
 
40
  # --- Global Initializations ---
41
- gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
 
42
  blink_detector = BlinkDetector()
43
 
44
- # Load Distraction Model
45
  distraction_model = YOLO(DISTRACTION_MODEL_PATH)
46
- distraction_model.to('cpu')
47
 
48
  # Distraction Class Names
49
  distraction_class_names = [
@@ -64,6 +65,9 @@ is_unconscious = False
64
  frame_count_webcam = 0
65
  stop_gaze_processing = False
66
 
 
 
 
67
  # Constants
68
  GAZE_STABILITY_THRESHOLD = 0.5
69
  TIME_THRESHOLD = 15
@@ -285,10 +289,22 @@ def terminate_gaze_stream():
285
  frame_count_webcam = 0
286
  return "Gaze Processing Terminated. State Reset."
287
 
 
 
 
 
 
 
 
288
  def process_gaze_frame(frame):
289
  global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
290
  global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
291
 
 
 
 
 
 
292
  if stop_gaze_processing:
293
  return np.zeros((480, 640, 3), dtype=np.uint8)
294
 
@@ -397,6 +413,66 @@ def process_gaze_frame(frame):
397
  cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
398
  return error_frame
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  def create_gaze_interface():
401
  with gr.Blocks() as gaze_demo:
402
  gr.Markdown("## Real-time Gaze & Drowsiness Tracking")
@@ -425,6 +501,24 @@ def create_distraction_interface():
425
  )
426
  return distraction_demo
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  def create_video_interface():
429
  video_demo = gr.Interface(
430
  fn=analyze_video,
@@ -436,8 +530,8 @@ def create_video_interface():
436
  return video_demo
437
 
438
  demo = gr.TabbedInterface(
439
- [create_video_interface(), create_gaze_interface(), create_distraction_interface()],
440
- ["Gaze Video Upload", "Gaze & Drowsiness (Live)", "Distraction Video Upload"],
441
  title="Driver Monitoring System"
442
  )
443
 
@@ -453,4 +547,5 @@ if __name__ == "__main__":
453
  is_unconscious = False
454
  frame_count_webcam = 0
455
  stop_gaze_processing = False
 
456
  demo.launch()
 
9
  from gradio_webrtc import WebRTC
10
  from ultralytics import YOLO
11
  import torch
12
+ import spaces # Add spaces import
13
 
14
  def smooth_values(history, current_value, window_size=5):
15
  if current_value is not None:
 
39
  DISTRACTION_MODEL_PATH = "best.pt"
40
 
41
  # --- Global Initializations ---
42
+ # Load models on CPU initially
43
+ gaze_predictor = GazePredictor(GAZE_MODEL_PATH, device='cpu') # Assuming GazePredictor accepts device arg
44
  blink_detector = BlinkDetector()
45
 
46
+ # Load Distraction Model on CPU initially
47
  distraction_model = YOLO(DISTRACTION_MODEL_PATH)
 
48
 
49
  # Distraction Class Names
50
  distraction_class_names = [
 
65
  frame_count_webcam = 0
66
  stop_gaze_processing = False
67
 
68
+ # --- Global State Variables for Distraction Webcam ---
69
+ stop_distraction_processing = False
70
+
71
  # Constants
72
  GAZE_STABILITY_THRESHOLD = 0.5
73
  TIME_THRESHOLD = 15
 
289
  frame_count_webcam = 0
290
  return "Gaze Processing Terminated. State Reset."
291
 
def terminate_distraction_stream():
    """Signal the live distraction stream to stop processing frames.

    Sets the module-level ``stop_distraction_processing`` flag to ``True``
    and logs the request to stdout.

    Returns:
        str: A status message confirming that processing was terminated.
    """
    global stop_distraction_processing
    print("Distraction Live Termination signal received. Stopping processing.")
    stop_distraction_processing = True
    return "Distraction Live Processing Terminated."
297
+
298
+ @spaces.GPU # Add ZeroGPU decorator
299
  def process_gaze_frame(frame):
300
  global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
301
  global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
302
 
303
+ try:
304
+ gaze_predictor.model.to('cuda')
305
+ except Exception as e:
306
+ print(f"Warning: Could not move gaze model to CUDA: {e}")
307
+
308
  if stop_gaze_processing:
309
  return np.zeros((480, 640, 3), dtype=np.uint8)
310
 
 
413
  cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
414
  return error_frame
415
 
@spaces.GPU  # ZeroGPU decorator
def process_distraction_frame(frame):
    """Run distraction detection on one live frame and return it annotated.

    Args:
        frame: Image array for the current frame, or ``None`` when no frame
            is available. The frame is converted with ``COLOR_RGB2BGR``
            before drawing, so it is presumably RGB -- confirm against the
            WebRTC component's output format.

    Returns:
        numpy.ndarray: The annotated frame (converted back with
        ``COLOR_BGR2RGB``). A black 480x640x3 ``uint8`` image is returned
        when processing has been stopped or ``frame`` is ``None``; a black
        image carrying the error text is returned if processing raises.
    """
    # Cheap exits first: no device transfer or inference is needed for them.
    if stop_distraction_processing or frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8)

    # Best-effort move to the GPU, mirroring process_gaze_frame: a failure is
    # logged and inference continues on whatever device the model is on.
    try:
        distraction_model.to('cuda')
    except Exception as e:
        print(f"Warning: Could not move distraction model to CUDA: {e}")

    try:
        results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)

        display_text = "safe driving"
        alarm_action = None

        for result in results:
            if result.boxes is None or len(result.boxes) == 0:
                continue

            scores = result.boxes.conf.cpu().numpy()
            classes = result.boxes.cls.cpu().numpy()

            # Report only the highest-confidence detection in this result.
            max_score_idx = scores.argmax()
            detected_action_idx = int(classes[max_score_idx])
            if 0 <= detected_action_idx < len(distraction_class_names):
                detected_action = distraction_class_names[detected_action_idx]
                confidence = scores[max_score_idx]
                display_text = f"{detected_action}: {confidence:.2f}"
                if detected_action != 'safe driving':
                    alarm_action = detected_action
            else:
                print(f"Warning: Detected class index {detected_action_idx} out of bounds.")
                display_text = "Unknown Detection"

        # Draw on a BGR copy so the colour tuples below are in BGR order.
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        if alarm_action:
            print(f"ALARM: Unsafe behavior detected - {alarm_action}!")
            cv2.putText(frame_bgr, f"ALARM: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
        cv2.putText(frame_bgr, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)

        return cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    except Exception as e:
        print(f"Error processing distraction frame: {e}")
        # Red text drawn directly on a black frame; same pixels as drawing
        # (0, 0, 255) on a BGR copy and swapping the channels back.
        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
        return error_frame
475
+
476
  def create_gaze_interface():
477
  with gr.Blocks() as gaze_demo:
478
  gr.Markdown("## Real-time Gaze & Drowsiness Tracking")
 
501
  )
502
  return distraction_demo
503
 
def create_distraction_live_interface():
    """Build the Gradio tab for live (webcam) distraction detection.

    The tab holds a WebRTC webcam component whose frames are passed through
    ``process_distraction_frame`` and rendered back into the same component,
    plus a button wired to ``terminate_distraction_stream``.

    Returns:
        The ``gr.Blocks`` instance containing the live distraction UI.
    """
    with gr.Blocks() as distraction_live_demo:
        gr.Markdown("## Real-time Distraction Detection (Live)")
        with gr.Row():
            webcam_stream = WebRTC(label="Webcam Stream")
        with gr.Row():
            terminate_btn = gr.Button("Terminate Process")

        # The same component is both the input source and the output sink.
        webcam_stream.stream(
            fn=process_distraction_frame,
            inputs=[webcam_stream],
            outputs=[webcam_stream],
        )

        terminate_btn.click(fn=terminate_distraction_stream, inputs=None, outputs=None)

    return distraction_live_demo
521
+
522
  def create_video_interface():
523
  video_demo = gr.Interface(
524
  fn=analyze_video,
 
530
  return video_demo
531
 
532
  demo = gr.TabbedInterface(
533
+ [create_video_interface(), create_gaze_interface(), create_distraction_interface(), create_distraction_live_interface()],
534
+ ["Gaze Video Upload", "Gaze & Drowsiness (Live)", "Distraction Video Upload", "Distraction Detection (Live)"],
535
  title="Driver Monitoring System"
536
  )
537
 
 
547
  is_unconscious = False
548
  frame_count_webcam = 0
549
  stop_gaze_processing = False
550
+ stop_distraction_processing = False
551
  demo.launch()
requirements.txt CHANGED
@@ -11,4 +11,6 @@ tensorflow
11
  pygame
12
  twilio
13
  ultralytics==8.3.93
14
- torch==2.6.0
 
 
 
11
  pygame
12
  twilio
13
  ultralytics==8.3.93
14
+ # torch==2.6.0 # Replace with ZeroGPU compatible version, e.g., 2.4.0
15
+ torch==2.4.0 # Example compatible version
16
+ spaces # Add spaces for ZeroGPU