import gradio as gr
import cv2
import numpy as np
import tempfile
import os
import time
import spaces
from scripts.inference import GazePredictor
from utils.ear_utils import BlinkDetector
from gradio_webrtc import WebRTC
from ultralytics import YOLO
import torch
import json
import requests

# --- Model cache variables ---
# Lazily-loaded YOLO distraction model shared by the video and live paths.
distraction_model_cache = None

# Real-valued scalar types accepted as numeric observations / coordinates.
# NumPy scalars such as np.float32 or np.int64 are not subclasses of the
# Python builtins, so they must be listed explicitly.
_REAL_SCALAR_TYPES = (int, float, np.integer, np.floating)


def smooth_values(history, current_value, window_size=5):
    """Record ``current_value`` in ``history`` and return a moving average.

    ``history`` is mutated in place and trimmed to its last ``window_size``
    entries. ``None`` is never recorded. Scalars (Python or NumPy) are averaged
    with ``np.mean``. Arrays are averaged element-wise when every recorded
    array has the same shape; otherwise the most recent entry is returned.

    Args:
        history: Mutable list of previously recorded values.
        current_value: New observation (scalar, ndarray or ``None``).
        window_size: Maximum number of entries kept in ``history``.

    Returns:
        The smoothed value. Returns ``current_value`` if nothing has been
        recorded, or the latest entry if the history mixes incompatible types.
    """
    if isinstance(current_value, (np.ndarray,) + _REAL_SCALAR_TYPES):
        history.append(current_value)
    while len(history) > window_size:
        history.pop(0)

    if not history:
        return current_value

    if all(isinstance(item, np.ndarray) for item in history):
        first_shape = history[0].shape
        if all(item.shape == first_shape for item in history):
            return np.mean(history, axis=0)
        return history[-1]
    if all(isinstance(item, _REAL_SCALAR_TYPES) for item in history):
        return np.mean(history)
    return history[-1]


# --- Configure Twilio TURN servers for WebRTC ---
def get_twilio_turn_credentials(timeout=10):
    """Fetch TURN/STUN ICE servers from Twilio's Network Traversal API.

    Credentials are read from the TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN
    environment variables. The call is best-effort: any failure is reported
    on stdout and yields ``None``, so the app falls back to public STUN.

    Args:
        timeout: Seconds to wait for Twilio before giving up. Without a
            timeout, a stalled connection would block module import.

    Returns:
        The ``ice_servers`` list from Twilio's response, or ``None``.
    """
    twilio_account_sid = os.environ.get("TWILIO_ACCOUNT_SID", "")
    twilio_auth_token = os.environ.get("TWILIO_AUTH_TOKEN", "")

    if not twilio_account_sid or not twilio_auth_token:
        print("Warning: Twilio credentials not found. Using default RTCConfiguration.")
        return None

    try:
        response = requests.post(
            f"https://api.twilio.com/2010-04-01/Accounts/{twilio_account_sid}/Tokens.json",
            auth=(twilio_account_sid, twilio_auth_token),
            timeout=timeout,
        )
        # Surface 4xx/5xx (e.g. bad credentials) instead of a confusing KeyError.
        response.raise_for_status()
        return response.json()["ice_servers"]
    except Exception as e:  # best-effort: never break app startup
        print(f"Error fetching Twilio TURN credentials: {e}")
        return None


# Configure WebRTC
ice_servers = get_twilio_turn_credentials()
if ice_servers:
    rtc_configuration = {"iceServers": ice_servers}
else:
    rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}

# --- Model Paths ---
GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
DISTRACTION_MODEL_PATH = "best.pt"

# --- Global Initializations ---
blink_detector = BlinkDetector()

# Distraction Class Names (index = YOLO class id)
distraction_class_names = [
    'safe driving', 'drinking', 'eating', 'hair and makeup',
    'operating radio', 'talking on phone', 'talking to passenger'
]

# --- Global State Variables for Gaze Webcam ---
gaze_history = []
head_history = []
ear_history = []
stable_gaze_time = 0
stable_head_time = 0
eye_closed_time = 0
blink_count = 0
start_time = 0
is_unconscious = False
frame_count_webcam = 0
stop_gaze_processing = False

# --- Global State Variables for Distraction Webcam ---
stop_distraction_processing = False

# Constants
GAZE_STABILITY_THRESHOLD = 0.5
TIME_THRESHOLD = 15
BLINK_RATE_THRESHOLD = 1
EYE_CLOSURE_THRESHOLD = 10
HEAD_STABILITY_THRESHOLD = 0.05
DISTRACTION_CONF_THRESHOLD = 0.1


def _remove_file_quietly(path):
    """Best-effort deletion of a temp file; a missing file is not an error."""
    try:
        os.remove(path)
    except OSError:
        pass


def _is_valid_point(point):
    """True if every coordinate of ``point`` is a non-negative real number."""
    return all(isinstance(coord, _REAL_SCALAR_TYPES) and coord >= 0 for coord in point)


def _get_distraction_model():
    """Return the shared YOLO distraction model, loading it on first use."""
    global distraction_model_cache
    if distraction_model_cache is None:
        distraction_model_cache = YOLO(DISTRACTION_MODEL_PATH)
        distraction_model_cache.to('cpu')
    return distraction_model_cache


def _detect_distraction(model, frame, draw_boxes=False):
    """Run the distraction model on ``frame`` and summarise the top detection.

    The status always reflects the single highest-confidence box. When
    ``draw_boxes`` is set, every box with a known class id is drawn onto
    ``frame`` in place (green for safe driving, red otherwise).

    Args:
        model: A loaded ultralytics YOLO model.
        frame: Image array handed to the model and drawn on.
        draw_boxes: Whether to annotate individual detections.

    Returns:
        ``(display_text, alarm_action)``. ``alarm_action`` is the detected
        unsafe class name, or ``None`` when driving is safe or nothing was found.
    """
    results = model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
    display_text = "safe driving"
    alarm_action = None

    for result in results:
        if result.boxes is None or len(result.boxes) == 0:
            continue
        boxes = result.boxes.xyxy.cpu().numpy()
        scores = result.boxes.conf.cpu().numpy()
        classes = result.boxes.cls.cpu().numpy()

        if draw_boxes:
            for box, cls_value, score in zip(boxes, classes, scores):
                cls_id = int(cls_value)
                if not 0 <= cls_id < len(distraction_class_names):
                    continue  # unknown class: reported below if it is the top box
                action = distraction_class_names[cls_id]
                color = (0, 255, 0) if action == "safe driving" else (0, 0, 255)
                x1, y1, x2, y2 = map(int, box)
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, f"{action} {score:.2f}", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # The status line is driven by the most confident detection only.
        top = int(scores.argmax())
        top_cls = int(classes[top])
        if 0 <= top_cls < len(distraction_class_names):
            detected_action = distraction_class_names[top_cls]
            display_text = f"{detected_action}: {scores[top]:.2f}"
            if detected_action != 'safe driving':
                alarm_action = detected_action
        else:
            print(f"Warning: Detected class index {top_cls} out of bounds.")
            display_text = "Unknown Detection"

    return display_text, alarm_action


@spaces.GPU(duration=60)  # Extended duration to 60 seconds for longer streaming
def analyze_video(input_video):
    """Annotate an uploaded video with gaze, head-pose, EAR and drowsiness cues.

    This call creates its own ``GazePredictor`` and ``BlinkDetector``, so
    webcam state is untouched. Timing comes from the frame index and the
    container FPS (falling back to 30), not from wall-clock time.

    Args:
        input_video: Path of the video to analyse, as provided by ``gr.Video``.

    Returns:
        Path of an ``.mp4`` temp file with the annotated frames. Returns
        ``None`` if the video is missing, cannot be opened, or yields no frames.
    """
    if not input_video:
        return None
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        print("Error: Could not open video file.")
        cap.release()
        return None

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
    os.close(temp_fd)
    out = None
    completed = False

    try:
        local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
        local_blink_detector = BlinkDetector()

        video_gaze_history = []
        video_head_history = []
        video_ear_history = []
        video_stable_gaze_time = 0
        video_stable_head_time = 0
        video_eye_closed_time = 0
        video_blink_count = 0
        video_start_time = 0
        video_frame_count = 0
        fps = cap.get(cv2.CAP_PROP_FPS) or 30

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            video_frame_count += 1
            current_time_video = video_frame_count / fps  # seconds into the clip
            if video_start_time == 0:
                video_start_time = current_time_video

            head_pose_gaze, gaze_h, gaze_v = local_gaze_predictor.predict_gaze(frame)
            current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
            smoothed_gaze = smooth_values(video_gaze_history, current_gaze)

            ear, left_eye, right_eye, head_pose, left_iris, right_iris = local_blink_detector.detect_blinks(frame)

            if ear is None:
                cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
                smoothed_head = smooth_values(video_head_history, None)
                smoothed_ear = smooth_values(video_ear_history, None)
            else:
                smoothed_head = smooth_values(video_head_history, head_pose)
                smoothed_ear = smooth_values(video_ear_history, ear)
                # Mark the irises only while the eyes are considered open.
                if smoothed_ear >= local_blink_detector.EAR_THRESHOLD and left_iris and right_iris:
                    if _is_valid_point(left_iris) and _is_valid_point(right_iris):
                        try:
                            cv2.drawMarker(frame, tuple(map(int, left_iris)), (0, 255, 0),
                                           markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
                            cv2.drawMarker(frame, tuple(map(int, right_iris)), (0, 255, 0),
                                           markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
                        except OverflowError:
                            print(f"Warning: OverflowError drawing iris markers at {left_iris}, {right_iris}")

            gaze_text_h = f"Gaze H: {smoothed_gaze[0]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 0 else "Gaze H: N/A"
            gaze_text_v = f"Gaze V: {smoothed_gaze[1]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 1 else "Gaze V: N/A"
            # NOTE(review): assumes head_pose is a scalar; an array value would make ":.2f" raise.
            head_text = f"Head Pose: {smoothed_head:.2f}" if smoothed_head is not None else "Head Pose: N/A"
            ear_text = f"EAR: {smoothed_ear:.2f}" if smoothed_ear is not None else "EAR: N/A"
            cv2.putText(frame, gaze_text_h, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.putText(frame, gaze_text_v, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.putText(frame, head_text, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.putText(frame, ear_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Gaze stability: start a timer when gaze stops moving, reset when it moves.
            if len(video_gaze_history) > 1 and smoothed_gaze is not None and video_gaze_history[-2] is not None:
                try:
                    gaze_diff = np.sqrt(np.sum((smoothed_gaze - video_gaze_history[-2]) ** 2))
                    if gaze_diff < GAZE_STABILITY_THRESHOLD:
                        if video_stable_gaze_time == 0:
                            video_stable_gaze_time = current_time_video
                    else:
                        video_stable_gaze_time = 0
                except TypeError:
                    video_stable_gaze_time = 0
            else:
                video_stable_gaze_time = 0

            # Head stability uses the same start/reset timer scheme.
            if len(video_head_history) > 1 and smoothed_head is not None and video_head_history[-2] is not None:
                head_diff = abs(smoothed_head - video_head_history[-2])
                if head_diff < HEAD_STABILITY_THRESHOLD:
                    if video_stable_head_time == 0:
                        video_stable_head_time = current_time_video
                else:
                    video_stable_head_time = 0
            else:
                video_stable_head_time = 0

            # Eye closure: a closure shorter than 0.5 s that reopens counts as a blink.
            if ear is not None and smoothed_ear is not None and smoothed_ear < local_blink_detector.EAR_THRESHOLD:
                if video_eye_closed_time == 0:
                    video_eye_closed_time = current_time_video
                elif current_time_video - video_eye_closed_time > EYE_CLOSURE_THRESHOLD:
                    cv2.putText(frame, "Eyes Closed", (10, 210), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            elif ear is not None:
                if video_eye_closed_time > 0 and current_time_video - video_eye_closed_time < 0.5:
                    video_blink_count += 1
                video_eye_closed_time = 0
            else:
                video_eye_closed_time = 0

            elapsed_seconds_video = current_time_video - video_start_time if video_start_time > 0 else 0
            elapsed_minutes_video = elapsed_seconds_video / 60
            blink_rate = video_blink_count / elapsed_minutes_video if elapsed_minutes_video > 0 else 0
            cv2.putText(frame, f"Blink Rate: {blink_rate:.1f}/min", (10, 240), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Flag unconsciousness when at least two independent cues agree.
            unconscious_conditions = [
                video_stable_gaze_time > 0 and current_time_video - video_stable_gaze_time > TIME_THRESHOLD,
                blink_rate < BLINK_RATE_THRESHOLD and elapsed_minutes_video > 1,
                video_eye_closed_time > 0 and current_time_video - video_eye_closed_time > EYE_CLOSURE_THRESHOLD,
                video_stable_head_time > 0 and current_time_video - video_stable_head_time > TIME_THRESHOLD,
            ]
            if sum(unconscious_conditions) >= 2:
                cv2.putText(frame, "Unconscious Detected", (10, 270), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            if out is None:
                h, w = frame.shape[:2]
                out = cv2.VideoWriter(temp_path, fourcc, fps, (w, h))
            out.write(frame)
        completed = True
    finally:
        cap.release()
        if out is not None:
            out.release()
        if not completed or out is None:
            # Failed run or no decodable frames: don't leave an unplayable temp file.
            _remove_file_quietly(temp_path)

    return temp_path if out is not None else None


@spaces.GPU(duration=60)  # Extended duration to 60 seconds for longer streaming
def analyze_distraction_video(input_video):
    """Annotate an uploaded video with the distraction model's top detection.

    Errors on individual frames are reported on stdout and stamped onto that
    frame; processing then continues with the next frame.

    Args:
        input_video: Path of the video to analyse, as provided by ``gr.Video``.

    Returns:
        Path of an ``.mp4`` temp file with the annotated frames. Returns
        ``None`` if the video is missing, cannot be opened, or yields no frames.
    """
    if not input_video:
        return None
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        print("Error: Could not open video file.")
        cap.release()
        return None

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
    os.close(temp_fd)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    out = None
    completed = False

    try:
        model = _get_distraction_model()
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            try:
                display_text, alarm_action = _detect_distraction(model, frame)
                if alarm_action:
                    print(f"ALARM: Unsafe behavior detected - {alarm_action}!")
                    cv2.putText(frame, f"ALARM: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
                cv2.putText(frame, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
            except Exception as e:  # keep going: one bad frame shouldn't abort the clip
                print(f"Error processing distraction frame in video: {e}")
                cv2.putText(frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

            if out is None:
                h, w = frame.shape[:2]
                out = cv2.VideoWriter(temp_path, fourcc, fps, (w, h))
            out.write(frame)
        completed = True
    finally:
        cap.release()
        if out is not None:
            out.release()
        if not completed or out is None:
            _remove_file_quietly(temp_path)

    return temp_path if out is not None else None


@spaces.GPU(duration=60)  # Extended duration to 60 seconds for longer streaming
def process_distraction_frame(frame):
    """Annotate one live webcam frame with distraction detections.

    Returns a black 640x480 frame if processing was terminated or no frame
    arrived. If detection fails, returns a black frame carrying the error text.
    Otherwise returns the annotated frame after a BGR->RGB conversion.
    """
    if stop_distraction_processing or frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8)

    model = _get_distraction_model()
    # OpenCV draws in place; incoming stream frames may be read-only views.
    if not frame.flags.writeable:
        frame = frame.copy()

    try:
        display_text, alarm_action = _detect_distraction(model, frame, draw_boxes=True)

        if alarm_action:
            cv2.putText(frame, f"ALERT: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        # Always show current detection status
        text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
        cv2.putText(frame, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)

        # NOTE(review): confirm the channel order WebRTC delivers/expects; if the
        # stream is BGR end-to-end, this conversion swaps red and blue on screen.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Error processing frame for distraction detection: {e}")
        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
        return error_frame


def terminate_gaze_stream():
    """Stop live gaze processing and reset all webcam gaze/blink state."""
    global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
    global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing

    print("Gaze Termination signal received. Stopping processing and resetting state.")
    stop_gaze_processing = True
    gaze_history = []
    head_history = []
    ear_history = []
    stable_gaze_time = 0
    stable_head_time = 0
    eye_closed_time = 0
    blink_count = 0
    start_time = 0
    is_unconscious = False
    frame_count_webcam = 0
    return "Gaze Processing Terminated. State Reset."
def terminate_distraction_stream():
    """Stop live distraction processing; later frames come back black."""
    global stop_distraction_processing
    print("Distraction Termination signal received. Stopping processing.")
    stop_distraction_processing = True
    return "Distraction Processing Terminated."


# Gaze model reused across live webcam frames. analyze_video likewise keeps a
# single predictor for a whole clip; rebuilding it per frame reloaded the
# model weights on every frame.
gaze_predictor_cache = None


def _get_gaze_predictor():
    """Return the shared live-stream GazePredictor, creating it on first use."""
    global gaze_predictor_cache
    if gaze_predictor_cache is None:
        gaze_predictor_cache = GazePredictor(GAZE_MODEL_PATH)
    return gaze_predictor_cache


@spaces.GPU(duration=60)  # Extended duration to 60 seconds for longer streaming
def process_gaze_frame(frame):
    """Annotate one live webcam frame with gaze, head-pose, EAR and drowsiness cues.

    State (histories, timers, blink count) lives in module globals and is
    reset by ``terminate_gaze_stream``. Timing uses wall-clock ``time.time()``.

    Returns a black 640x480 frame if processing was terminated or no frame
    arrived. If processing fails, returns a black frame carrying the error
    text. Otherwise returns the annotated frame after a BGR->RGB conversion.
    """
    global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
    global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing

    if stop_gaze_processing:
        return np.zeros((480, 640, 3), dtype=np.uint8)
    if frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8)

    frame_count_webcam += 1
    current_time = time.time()
    if start_time == 0:
        start_time = current_time

    local_gaze_predictor = _get_gaze_predictor()
    # OpenCV draws in place; incoming stream frames may be read-only views.
    if not frame.flags.writeable:
        frame = frame.copy()

    real_types = (int, float, np.integer, np.floating)

    try:
        head_pose_gaze, gaze_h, gaze_v = local_gaze_predictor.predict_gaze(frame)
        current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
        smoothed_gaze = smooth_values(gaze_history, current_gaze)

        ear, left_eye, right_eye, head_pose, left_iris, right_iris = blink_detector.detect_blinks(frame)

        if ear is None:
            cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
            smoothed_head = smooth_values(head_history, None)
            smoothed_ear = smooth_values(ear_history, None)
        else:
            smoothed_head = smooth_values(head_history, head_pose)
            smoothed_ear = smooth_values(ear_history, ear)
            # Mark the irises only while the eyes are considered open.
            if smoothed_ear >= blink_detector.EAR_THRESHOLD and left_iris and right_iris:
                if all(isinstance(coord, real_types) and coord >= 0 for coord in left_iris) and \
                   all(isinstance(coord, real_types) and coord >= 0 for coord in right_iris):
                    try:
                        cv2.drawMarker(frame, tuple(map(int, left_iris)), (0, 255, 0),
                                       markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
                        cv2.drawMarker(frame, tuple(map(int, right_iris)), (0, 255, 0),
                                       markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
                    except OverflowError:
                        print(f"Warning: OverflowError drawing iris markers at {left_iris}, {right_iris}")

        gaze_text_h = f"Gaze H: {smoothed_gaze[0]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 0 else "Gaze H: N/A"
        gaze_text_v = f"Gaze V: {smoothed_gaze[1]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 1 else "Gaze V: N/A"
        head_text = f"Head Pose: {smoothed_head:.2f}" if smoothed_head is not None else "Head Pose: N/A"
        ear_text = f"EAR: {smoothed_ear:.2f}" if smoothed_ear is not None else "EAR: N/A"
        cv2.putText(frame, gaze_text_h, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(frame, gaze_text_v, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(frame, head_text, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(frame, ear_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # Gaze stability: start a timer when gaze stops moving, reset when it moves.
        if len(gaze_history) > 1 and smoothed_gaze is not None and gaze_history[-2] is not None:
            try:
                gaze_diff = np.sqrt(np.sum((smoothed_gaze - gaze_history[-2]) ** 2))
                if gaze_diff < GAZE_STABILITY_THRESHOLD:
                    if stable_gaze_time == 0:
                        stable_gaze_time = current_time
                else:
                    stable_gaze_time = 0
            except TypeError:
                stable_gaze_time = 0
        else:
            stable_gaze_time = 0

        # Head stability uses the same start/reset timer scheme.
        if len(head_history) > 1 and smoothed_head is not None and head_history[-2] is not None:
            head_diff = abs(smoothed_head - head_history[-2])
            if head_diff < HEAD_STABILITY_THRESHOLD:
                if stable_head_time == 0:
                    stable_head_time = current_time
            else:
                stable_head_time = 0
        else:
            stable_head_time = 0

        # Eye closure: a closure shorter than 0.5 s that reopens counts as a blink.
        if ear is not None and smoothed_ear is not None and smoothed_ear < blink_detector.EAR_THRESHOLD:
            if eye_closed_time == 0:
                eye_closed_time = current_time
            elif current_time - eye_closed_time > EYE_CLOSURE_THRESHOLD:
                cv2.putText(frame, "Eyes Closed", (10, 210), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        elif ear is not None:
            if eye_closed_time > 0 and current_time - eye_closed_time < 0.5:
                blink_count += 1
            eye_closed_time = 0
        else:
            eye_closed_time = 0

        elapsed_seconds = current_time - start_time if start_time > 0 else 0
        elapsed_minutes = elapsed_seconds / 60
        blink_rate = blink_count / elapsed_minutes if elapsed_minutes > 0 else 0
        cv2.putText(frame, f"Blink Rate: {blink_rate:.1f}/min", (10, 240), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # Flag unconsciousness when at least two independent cues agree.
        unconscious_conditions = [
            stable_gaze_time > 0 and current_time - stable_gaze_time > TIME_THRESHOLD,
            blink_rate < BLINK_RATE_THRESHOLD and elapsed_minutes > 1,
            eye_closed_time > 0 and current_time - eye_closed_time > EYE_CLOSURE_THRESHOLD,
            stable_head_time > 0 and current_time - stable_head_time > TIME_THRESHOLD,
        ]
        if sum(unconscious_conditions) >= 2:
            cv2.putText(frame, "Unconscious Detected", (10, 270), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            is_unconscious = True
        else:
            is_unconscious = False

        # NOTE(review): confirm the channel order WebRTC delivers/expects.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    except Exception as e:
        print(f"Error processing frame: {e}")
        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
        return error_frame


def create_gaze_interface():
    """Build the live gaze/drowsiness tab: a WebRTC stream plus a terminate button."""
    with gr.Blocks() as gaze_demo:
        gr.Markdown("## Real-time Gaze & Drowsiness Tracking")
        with gr.Row():
            webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
        with gr.Row():
            terminate_btn = gr.Button("Terminate Process")

        webcam_stream.stream(
            fn=process_gaze_frame,
            inputs=[webcam_stream],
            outputs=[webcam_stream]
        )
        terminate_btn.click(fn=terminate_gaze_stream, inputs=None, outputs=None)

    return gaze_demo


def create_distraction_interface():
    """Build the live distraction tab: a WebRTC stream plus a terminate button."""
    with gr.Blocks() as distraction_demo:
        gr.Markdown("## Real-time Distraction Detection")
        with gr.Row():
            webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
        with gr.Row():
            terminate_btn = gr.Button("Terminate Process")

        webcam_stream.stream(
            fn=process_distraction_frame,
            inputs=[webcam_stream],
            outputs=[webcam_stream]
        )
        terminate_btn.click(fn=terminate_distraction_stream, inputs=None, outputs=None)

    return distraction_demo


def create_video_interface():
    """Build the uploaded-video gaze analysis tab."""
    video_demo = gr.Interface(
        fn=analyze_video,
        inputs=gr.Video(),
        outputs=gr.Video(),
        title="Gaze Detection",
        description="Analyze gaze in realtime."
    )
    return video_demo


# NOTE(review): create_gaze_interface() is defined but not mounted as a tab.
demo = gr.TabbedInterface(
    [create_video_interface(), create_distraction_interface()],
    ["Gaze Detection", "Distraction Detection (Live)"],
    title="DriveAware"
)

if __name__ == "__main__":
    # Module-level state is already freshly initialised at import time.
    demo.launch()