import gradio as gr
import cv2
import numpy as np
import tempfile
import os
import time
import spaces
from scripts.inference import GazePredictor
from utils.ear_utils import BlinkDetector
from gradio_webrtc import WebRTC
from ultralytics import YOLO
import torch
import json
import requests
# --- Model cache variables ---
distraction_model_cache = None
def smooth_values(history, current_value, window_size=5):
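    """Append current_value to history and return a moving average over the last window_size samples.

    Works for both scalar readings (e.g. EAR, head pose) and small vectors
    (e.g. the [gaze_h, gaze_v] array); mixed or shape-mismatched entries fall
    back to the most recent value. Illustrative example (assumed behaviour):

        hist = []
        smooth_values(hist, 1.0)   # -> 1.0
        smooth_values(hist, 3.0)   # -> 2.0 (mean of the last two samples)
    """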
if current_value is not None:
if isinstance(current_value, np.ndarray):
history.append(current_value)
elif isinstance(current_value, (int, float)):
history.append(current_value)
if len(history) > window_size:
history.pop(0)
if not history:
return current_value
if all(isinstance(item, np.ndarray) for item in history):
first_shape = history[0].shape
if all(item.shape == first_shape for item in history):
return np.mean(history, axis=0)
else:
return history[-1] if history else None
elif all(isinstance(item, (int, float)) for item in history):
return np.mean(history)
else:
return history[-1] if history else None
# --- Configure Twilio TURN servers for WebRTC ---
def get_twilio_turn_credentials():
    # Twilio credentials are read from the TWILIO_ACCOUNT_SID / TWILIO_AUTH_TOKEN environment variables
twilio_account_sid = os.environ.get("TWILIO_ACCOUNT_SID", "")
twilio_auth_token = os.environ.get("TWILIO_AUTH_TOKEN", "")
if not twilio_account_sid or not twilio_auth_token:
print("Warning: Twilio credentials not found. Using default RTCConfiguration.")
return None
try:
        response = requests.post(
            f"https://api.twilio.com/2010-04-01/Accounts/{twilio_account_sid}/Tokens.json",
            auth=(twilio_account_sid, twilio_auth_token),
            timeout=10,
        )
        response.raise_for_status()
        data = response.json()
        return data["ice_servers"]
except Exception as e:
print(f"Error fetching Twilio TURN credentials: {e}")
return None
# Configure WebRTC
ice_servers = get_twilio_turn_credentials()
if ice_servers:
rtc_configuration = {"iceServers": ice_servers}
else:
rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
# --- Model Paths ---
GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
DISTRACTION_MODEL_PATH = "best.pt"
# --- Global Initializations ---
blink_detector = BlinkDetector()
# Distraction Class Names
distraction_class_names = [
'safe driving', 'drinking', 'eating', 'hair and makeup',
'operating radio', 'talking on phone', 'talking to passenger'
]
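# Assumption: the list above mirrors the class order the YOLO weights (best.pt) were
# trained with; class indices returned by the model are mapped through it below.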
# --- Global State Variables for Gaze Webcam ---
gaze_history = []
head_history = []
ear_history = []
stable_gaze_time = 0
stable_head_time = 0
eye_closed_time = 0
blink_count = 0
start_time = 0
is_unconscious = False
frame_count_webcam = 0
stop_gaze_processing = False
# --- Global State Variables for Distraction Webcam ---
stop_distraction_processing = False
# Constants
GAZE_STABILITY_THRESHOLD = 0.5      # max gaze change (in the gaze model's output units) still counted as "stable"
TIME_THRESHOLD = 15                 # seconds of unchanged gaze/head pose before flagging
BLINK_RATE_THRESHOLD = 1            # blinks per minute below which drowsiness is suspected
EYE_CLOSURE_THRESHOLD = 10          # seconds of continuous eye closure before flagging
HEAD_STABILITY_THRESHOLD = 0.05     # max head-pose change still counted as "stable"
DISTRACTION_CONF_THRESHOLD = 0.1    # minimum YOLO confidence for a distraction detection
@spaces.GPU(duration=60) # Extended duration to 60 seconds for longer streaming
def analyze_video(input_video):
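    """Run gaze and drowsiness analysis on an uploaded video file.

    Reads the video frame by frame (timing reconstructed from frame index and FPS),
    overlays gaze, head-pose, EAR and blink-rate annotations, and flags
    "Unconscious Detected" when at least two drowsiness conditions hold.
    Returns the path of the annotated MP4 written to a temporary file,
    which Gradio serves back to the user.
    """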
cap = cv2.VideoCapture(input_video)
local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
local_blink_detector = BlinkDetector()
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
os.close(temp_fd)
out = None
video_gaze_history = []
video_head_history = []
video_ear_history = []
video_stable_gaze_time = 0
video_stable_head_time = 0
video_eye_closed_time = 0
video_blink_count = 0
video_start_time = 0
video_is_unconscious = False
video_frame_count = 0
fps = cap.get(cv2.CAP_PROP_FPS) or 30
while True:
ret, frame = cap.read()
if not ret:
break
video_frame_count += 1
current_time_video = video_frame_count / fps
if video_start_time == 0:
video_start_time = current_time_video
head_pose_gaze, gaze_h, gaze_v = local_gaze_predictor.predict_gaze(frame)
current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
smoothed_gaze = smooth_values(video_gaze_history, current_gaze)
ear, left_eye, right_eye, head_pose, left_iris, right_iris = local_blink_detector.detect_blinks(frame)
if ear is None:
cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
smoothed_head = smooth_values(video_head_history, None)
smoothed_ear = smooth_values(video_ear_history, None)
else:
smoothed_head = smooth_values(video_head_history, head_pose)
smoothed_ear = smooth_values(video_ear_history, ear)
if smoothed_ear >= local_blink_detector.EAR_THRESHOLD and left_iris and right_iris:
if all(isinstance(coord, (int, float)) and coord >= 0 for coord in left_iris) and \
all(isinstance(coord, (int, float)) and coord >= 0 for coord in right_iris):
try:
cv2.drawMarker(frame, tuple(map(int, left_iris)), (0, 255, 0), markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
cv2.drawMarker(frame, tuple(map(int, right_iris)), (0, 255, 0), markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
except OverflowError:
print(f"Warning: OverflowError drawing iris markers at {left_iris}, {right_iris}")
gaze_text_h = f"Gaze H: {smoothed_gaze[0]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 0 else "Gaze H: N/A"
gaze_text_v = f"Gaze V: {smoothed_gaze[1]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 1 else "Gaze V: N/A"
head_text = f"Head Pose: {smoothed_head:.2f}" if smoothed_head is not None else "Head Pose: N/A"
ear_text = f"EAR: {smoothed_ear:.2f}" if smoothed_ear is not None else "EAR: N/A"
cv2.putText(frame, gaze_text_h, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, gaze_text_v, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, head_text, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, ear_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
if len(video_gaze_history) > 1 and smoothed_gaze is not None and video_gaze_history[-2] is not None:
try:
gaze_diff = np.sqrt(np.sum((smoothed_gaze - video_gaze_history[-2])**2))
if gaze_diff < GAZE_STABILITY_THRESHOLD:
if video_stable_gaze_time == 0:
video_stable_gaze_time = current_time_video
else:
video_stable_gaze_time = 0
except TypeError:
video_stable_gaze_time = 0
else:
video_stable_gaze_time = 0
if len(video_head_history) > 1 and smoothed_head is not None and video_head_history[-2] is not None:
head_diff = abs(smoothed_head - video_head_history[-2])
if head_diff < HEAD_STABILITY_THRESHOLD:
if video_stable_head_time == 0:
video_stable_head_time = current_time_video
else:
video_stable_head_time = 0
else:
video_stable_head_time = 0
if ear is not None and smoothed_ear is not None and smoothed_ear < local_blink_detector.EAR_THRESHOLD:
if video_eye_closed_time == 0:
video_eye_closed_time = current_time_video
elif current_time_video - video_eye_closed_time > EYE_CLOSURE_THRESHOLD:
cv2.putText(frame, "Eyes Closed", (10, 210), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
elif ear is not None:
if video_eye_closed_time > 0 and current_time_video - video_eye_closed_time < 0.5:
video_blink_count += 1
video_eye_closed_time = 0
else:
video_eye_closed_time = 0
elapsed_seconds_video = current_time_video - video_start_time if video_start_time > 0 else 0
elapsed_minutes_video = elapsed_seconds_video / 60
blink_rate = video_blink_count / elapsed_minutes_video if elapsed_minutes_video > 0 else 0
cv2.putText(frame, f"Blink Rate: {blink_rate:.1f}/min", (10, 240), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
unconscious_conditions = [
video_stable_gaze_time > 0 and current_time_video - video_stable_gaze_time > TIME_THRESHOLD,
blink_rate < BLINK_RATE_THRESHOLD and elapsed_minutes_video > 1,
video_eye_closed_time > 0 and current_time_video - video_eye_closed_time > EYE_CLOSURE_THRESHOLD,
video_stable_head_time > 0 and current_time_video - video_stable_head_time > TIME_THRESHOLD
]
if sum(unconscious_conditions) >= 2:
cv2.putText(frame, "Unconscious Detected", (10, 270), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
video_is_unconscious = True
else:
video_is_unconscious = False
if out is None:
h, w = frame.shape[:2]
out = cv2.VideoWriter(temp_path, fourcc, fps, (w, h))
out.write(frame)
cap.release()
if out:
out.release()
return temp_path
@spaces.GPU(duration=60) # Extended duration to 60 seconds for longer streaming
def analyze_distraction_video(input_video):
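    """Run the YOLO distraction detector on an uploaded video file.

    Annotates each frame with the highest-confidence detected action and an
    ALARM overlay for anything other than 'safe driving', then returns the
    path of the annotated MP4 written to a temporary file.
    """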
cap = cv2.VideoCapture(input_video)
if not cap.isOpened():
print("Error: Could not open video file.")
return None
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
os.close(temp_fd)
out = None
fps = cap.get(cv2.CAP_PROP_FPS) or 30
global distraction_model_cache
if distraction_model_cache is None:
distraction_model_cache = YOLO(DISTRACTION_MODEL_PATH)
distraction_model_cache.to('cpu')
while True:
ret, frame = cap.read()
if not ret:
break
try:
results = distraction_model_cache(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
display_text = "safe driving"
alarm_action = None
for result in results:
if result.boxes is not None and len(result.boxes) > 0:
boxes = result.boxes.xyxy.cpu().numpy()
scores = result.boxes.conf.cpu().numpy()
classes = result.boxes.cls.cpu().numpy()
if len(boxes) > 0:
max_score_idx = scores.argmax()
detected_action_idx = int(classes[max_score_idx])
if 0 <= detected_action_idx < len(distraction_class_names):
detected_action = distraction_class_names[detected_action_idx]
confidence = scores[max_score_idx]
display_text = f"{detected_action}: {confidence:.2f}"
if detected_action != 'safe driving':
alarm_action = detected_action
else:
print(f"Warning: Detected class index {detected_action_idx} out of bounds.")
display_text = "Unknown Detection"
if alarm_action:
print(f"ALARM: Unsafe behavior detected - {alarm_action}!")
cv2.putText(frame, f"ALARM: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
cv2.putText(frame, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
if out is None:
h, w = frame.shape[:2]
out = cv2.VideoWriter(temp_path, fourcc, fps, (w, h))
out.write(frame)
except Exception as e:
print(f"Error processing distraction frame in video: {e}")
if out is None:
h, w = frame.shape[:2]
out = cv2.VideoWriter(temp_path, fourcc, fps, (w, h))
cv2.putText(frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
out.write(frame)
cap.release()
if out:
out.release()
return temp_path
@spaces.GPU(duration=60) # Extended duration to 60 seconds for longer streaming
def process_distraction_frame(frame):
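    """Process a single webcam frame for distraction detection (WebRTC streaming).

    Draws a bounding box and label for every detection, keeps the highest-
    confidence action as the status line, and returns the annotated frame in
    RGB order for Gradio display. Returns a black frame once the stream has
    been terminated or when no frame is received.
    """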
global stop_distraction_processing
global distraction_model_cache
if stop_distraction_processing:
return np.zeros((480, 640, 3), dtype=np.uint8)
if frame is None:
return np.zeros((480, 640, 3), dtype=np.uint8)
if distraction_model_cache is None:
distraction_model_cache = YOLO(DISTRACTION_MODEL_PATH)
distraction_model_cache.to('cpu')
try:
# Run distraction detection model
results = distraction_model_cache(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
display_text = "safe driving"
alarm_action = None
for result in results:
if result.boxes is not None and len(result.boxes) > 0:
boxes = result.boxes.xyxy.cpu().numpy()
scores = result.boxes.conf.cpu().numpy()
classes = result.boxes.cls.cpu().numpy()
if len(boxes) > 0:
# Draw bounding boxes
for i, box in enumerate(boxes):
x1, y1, x2, y2 = map(int, box)
cls_id = int(classes[i])
confidence = scores[i]
if 0 <= cls_id < len(distraction_class_names):
action = distraction_class_names[cls_id]
color = (0, 255, 0) if action == "safe driving" else (0, 0, 255)
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
cv2.putText(frame, f"{action} {confidence:.2f}", (x1, y1-10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
# Select highest confidence detection for status
if i == scores.argmax():
detected_action = action
confidence_score = confidence
display_text = f"{detected_action}: {confidence_score:.2f}"
if detected_action != 'safe driving':
alarm_action = detected_action
else:
print(f"Warning: Detected class index {cls_id} out of bounds.")
display_text = "Unknown Detection"
if alarm_action:
cv2.putText(frame, f"ALERT: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
# Always show current detection status
text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
cv2.putText(frame, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
# Convert BGR to RGB for Gradio display
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return frame_rgb
except Exception as e:
print(f"Error processing frame for distraction detection: {e}")
error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
if not error_frame.flags.writeable:
error_frame = error_frame.copy()
cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
return error_frame
def terminate_gaze_stream():
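    """Stop the gaze webcam stream and reset all of its global state.

    Note: stop_gaze_processing remains True afterwards, so the stream only
    resumes once the flag is reset (currently done at startup in __main__).
    """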
global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
print("Gaze Termination signal received. Stopping processing and resetting state.")
stop_gaze_processing = True
gaze_history = []
head_history = []
ear_history = []
stable_gaze_time = 0
stable_head_time = 0
eye_closed_time = 0
blink_count = 0
start_time = 0
is_unconscious = False
frame_count_webcam = 0
return "Gaze Processing Terminated. State Reset."
def terminate_distraction_stream():
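    """Stop the distraction webcam stream by setting its global stop flag."""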
global stop_distraction_processing
print("Distraction Termination signal received. Stopping processing.")
stop_distraction_processing = True
return "Distraction Processing Terminated."
@spaces.GPU(duration=60) # Extended duration to 60 seconds for longer streaming
def process_gaze_frame(frame):
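    """Process a single webcam frame for gaze and drowsiness tracking (WebRTC streaming).

    Mirrors analyze_video but keeps its state in module-level globals so blink
    counts and stability timers persist across frames. Returns the annotated
    frame in RGB order, or a black frame once the stream has been terminated
    or when no frame is received.
    """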
global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
if stop_gaze_processing:
return np.zeros((480, 640, 3), dtype=np.uint8)
if frame is None:
return np.zeros((480, 640, 3), dtype=np.uint8)
frame_count_webcam += 1
current_time = time.time()
if start_time == 0:
start_time = current_time
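    # Note: a new GazePredictor (and its weights) is constructed on every frame;
    # caching it at module level, as done for the YOLO model, would likely be cheaper.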
local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
try:
head_pose_gaze, gaze_h, gaze_v = local_gaze_predictor.predict_gaze(frame)
current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
smoothed_gaze = smooth_values(gaze_history, current_gaze)
ear, left_eye, right_eye, head_pose, left_iris, right_iris = blink_detector.detect_blinks(frame)
if ear is None:
cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
smoothed_head = smooth_values(head_history, None)
smoothed_ear = smooth_values(ear_history, None)
else:
smoothed_head = smooth_values(head_history, head_pose)
smoothed_ear = smooth_values(ear_history, ear)
if smoothed_ear >= blink_detector.EAR_THRESHOLD and left_iris and right_iris:
if all(isinstance(coord, (int, float)) and coord >= 0 for coord in left_iris) and \
all(isinstance(coord, (int, float)) and coord >= 0 for coord in right_iris):
try:
cv2.drawMarker(frame, tuple(map(int, left_iris)), (0, 255, 0), markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
cv2.drawMarker(frame, tuple(map(int, right_iris)), (0, 255, 0), markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
except OverflowError:
print(f"Warning: OverflowError drawing iris markers at {left_iris}, {right_iris}")
gaze_text_h = f"Gaze H: {smoothed_gaze[0]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 0 else "Gaze H: N/A"
gaze_text_v = f"Gaze V: {smoothed_gaze[1]:.2f}" if smoothed_gaze is not None and len(smoothed_gaze) > 1 else "Gaze V: N/A"
head_text = f"Head Pose: {smoothed_head:.2f}" if smoothed_head is not None else "Head Pose: N/A"
ear_text = f"EAR: {smoothed_ear:.2f}" if smoothed_ear is not None else "EAR: N/A"
cv2.putText(frame, gaze_text_h, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, gaze_text_v, (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, head_text, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.putText(frame, ear_text, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
if len(gaze_history) > 1 and smoothed_gaze is not None and gaze_history[-2] is not None:
try:
gaze_diff = np.sqrt(np.sum((smoothed_gaze - gaze_history[-2])**2))
if gaze_diff < GAZE_STABILITY_THRESHOLD:
if stable_gaze_time == 0:
stable_gaze_time = current_time
else:
stable_gaze_time = 0
except TypeError:
stable_gaze_time = 0
else:
stable_gaze_time = 0
if len(head_history) > 1 and smoothed_head is not None and head_history[-2] is not None:
head_diff = abs(smoothed_head - head_history[-2])
if head_diff < HEAD_STABILITY_THRESHOLD:
if stable_head_time == 0:
stable_head_time = current_time
else:
stable_head_time = 0
else:
stable_head_time = 0
if ear is not None and smoothed_ear is not None and smoothed_ear < blink_detector.EAR_THRESHOLD:
if eye_closed_time == 0:
eye_closed_time = current_time
elif current_time - eye_closed_time > EYE_CLOSURE_THRESHOLD:
cv2.putText(frame, "Eyes Closed", (10, 210), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
elif ear is not None:
if eye_closed_time > 0 and current_time - eye_closed_time < 0.5:
blink_count += 1
eye_closed_time = 0
else:
eye_closed_time = 0
elapsed_seconds = current_time - start_time if start_time > 0 else 0
elapsed_minutes = elapsed_seconds / 60
blink_rate = blink_count / elapsed_minutes if elapsed_minutes > 0 else 0
cv2.putText(frame, f"Blink Rate: {blink_rate:.1f}/min", (10, 240), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
unconscious_conditions = [
stable_gaze_time > 0 and current_time - stable_gaze_time > TIME_THRESHOLD,
blink_rate < BLINK_RATE_THRESHOLD and elapsed_minutes > 1,
eye_closed_time > 0 and current_time - eye_closed_time > EYE_CLOSURE_THRESHOLD,
stable_head_time > 0 and current_time - stable_head_time > TIME_THRESHOLD
]
if sum(unconscious_conditions) >= 2:
cv2.putText(frame, "Unconscious Detected", (10, 270), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
is_unconscious = True
else:
is_unconscious = False
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return frame_rgb
except Exception as e:
print(f"Error processing frame: {e}")
error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
if not error_frame.flags.writeable:
error_frame = error_frame.copy()
cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
return error_frame
def create_gaze_interface():
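    """Build the Gradio Blocks UI for live gaze and drowsiness tracking over WebRTC."""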
with gr.Blocks() as gaze_demo:
gr.Markdown("## Real-time Gaze & Drowsiness Tracking")
with gr.Row():
webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
with gr.Row():
terminate_btn = gr.Button("Terminate Process")
webcam_stream.stream(
fn=process_gaze_frame,
inputs=[webcam_stream],
outputs=[webcam_stream]
)
terminate_btn.click(fn=terminate_gaze_stream, inputs=None, outputs=None)
return gaze_demo
def create_distraction_interface():
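    """Build the Gradio Blocks UI for live distraction detection over WebRTC."""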
with gr.Blocks() as distraction_demo:
gr.Markdown("## Real-time Distraction Detection")
with gr.Row():
webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
with gr.Row():
terminate_btn = gr.Button("Terminate Process")
webcam_stream.stream(
fn=process_distraction_frame,
inputs=[webcam_stream],
outputs=[webcam_stream]
)
terminate_btn.click(fn=terminate_distraction_stream, inputs=None, outputs=None)
return distraction_demo
def create_video_interface():
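    """Build the Gradio Interface for offline (uploaded-video) gaze analysis."""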
video_demo = gr.Interface(
fn=analyze_video,
inputs=gr.Video(),
outputs=gr.Video(),
title="Gaze Detection",
description="Analyze gaze in realtime."
)
return video_demo
demo = gr.TabbedInterface(
[create_video_interface(), create_distraction_interface()],
["Gaze Detection", "Distraction Detection (Live)"],
title="DriveAware"
)
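# Note: create_gaze_interface() (the live WebRTC gaze tab) is defined above but not
# included in the tabs; presumably it could be added as a third entry in both lists.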
if __name__ == "__main__":
gaze_history = []
head_history = []
ear_history = []
stable_gaze_time = 0
stable_head_time = 0
eye_closed_time = 0
blink_count = 0
start_time = 0
is_unconscious = False
frame_count_webcam = 0
stop_gaze_processing = False
stop_distraction_processing = False
demo.launch()