Guru-25 committed on
Commit
d343b30
·
verified ·
1 Parent(s): 325bf36
app.py CHANGED
@@ -69,13 +69,8 @@ GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
69
  DISTRACTION_MODEL_PATH = "best.pt"
70
 
71
  # --- Global Initializations ---
72
- gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
73
  blink_detector = BlinkDetector()
74
 
75
- # Load Distraction Model
76
- distraction_model = YOLO(DISTRACTION_MODEL_PATH)
77
- distraction_model.to('cpu')
78
-
79
  # Distraction Class Names
80
  distraction_class_names = [
81
  'safe driving', 'drinking', 'eating', 'hair and makeup',
@@ -106,6 +101,7 @@ EYE_CLOSURE_THRESHOLD = 10
106
  HEAD_STABILITY_THRESHOLD = 0.05
107
  DISTRACTION_CONF_THRESHOLD = 0.1
108
 
 
109
  def analyze_video(input_video):
110
  cap = cv2.VideoCapture(input_video)
111
  local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
@@ -247,13 +243,16 @@ def analyze_distraction_video(input_video):
247
 
248
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
249
 
 
 
 
250
  while True:
251
  ret, frame = cap.read()
252
  if not ret:
253
  break
254
 
255
  try:
256
- results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
257
 
258
  display_text = "safe driving"
259
  alarm_action = None
@@ -312,9 +311,12 @@ def process_distraction_frame(frame):
312
  if frame is None:
313
  return np.zeros((480, 640, 3), dtype=np.uint8)
314
 
 
 
 
315
  try:
316
  # Run distraction detection model
317
- results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
318
 
319
  display_text = "safe driving"
320
  alarm_action = None
@@ -410,8 +412,10 @@ def process_gaze_frame(frame):
410
  if start_time == 0:
411
  start_time = current_time
412
 
 
 
413
  try:
414
- head_pose_gaze, gaze_h, gaze_v = gaze_predictor.predict_gaze(frame)
415
  current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
416
  smoothed_gaze = smooth_values(gaze_history, current_gaze)
417
 
 
69
  DISTRACTION_MODEL_PATH = "best.pt"
70
 
71
  # --- Global Initializations ---
 
72
  blink_detector = BlinkDetector()
73
 
 
 
 
 
74
  # Distraction Class Names
75
  distraction_class_names = [
76
  'safe driving', 'drinking', 'eating', 'hair and makeup',
 
101
  HEAD_STABILITY_THRESHOLD = 0.05
102
  DISTRACTION_CONF_THRESHOLD = 0.1
103
 
104
+ @spaces.GPU(duration=30) # Set duration to 30 seconds for real-time processing
105
  def analyze_video(input_video):
106
  cap = cv2.VideoCapture(input_video)
107
  local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
 
243
 
244
  fps = cap.get(cv2.CAP_PROP_FPS) or 30
245
 
246
+ local_distraction_model = YOLO(DISTRACTION_MODEL_PATH)
247
+ local_distraction_model.to('cpu')
248
+
249
  while True:
250
  ret, frame = cap.read()
251
  if not ret:
252
  break
253
 
254
  try:
255
+ results = local_distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
256
 
257
  display_text = "safe driving"
258
  alarm_action = None
 
311
  if frame is None:
312
  return np.zeros((480, 640, 3), dtype=np.uint8)
313
 
314
+ local_distraction_model = YOLO(DISTRACTION_MODEL_PATH)
315
+ local_distraction_model.to('cpu')
316
+
317
  try:
318
  # Run distraction detection model
319
+ results = local_distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
320
 
321
  display_text = "safe driving"
322
  alarm_action = None
 
412
  if start_time == 0:
413
  start_time = current_time
414
 
415
+ local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
416
+
417
  try:
418
+ head_pose_gaze, gaze_h, gaze_v = local_gaze_predictor.predict_gaze(frame)
419
  current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
420
  smoothed_gaze = smooth_values(gaze_history, current_gaze)
421
 
scripts/__pycache__/inference.cpython-312.pyc CHANGED
Binary files a/scripts/__pycache__/inference.cpython-312.pyc and b/scripts/__pycache__/inference.cpython-312.pyc differ
 
scripts/inference.py CHANGED
@@ -17,13 +17,29 @@ class GazeEstimationModel(torch.nn.Module):
17
 
18
  class GazePredictor:
19
  def __init__(self, model_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
  # Initialize the custom model
23
  self.model = GazeEstimationModel()
24
 
25
  # Load the state dictionary
26
- state_dict = torch.load(model_path, map_location=self.device)
27
 
28
  # Check if state_dict has 'backbone.' prefix and strip it if necessary
29
  new_state_dict = {}
@@ -42,14 +58,11 @@ class GazePredictor:
42
  # Move to device and set to evaluation mode
43
  self.model.to(self.device)
44
  self.model.eval()
45
-
46
- # Define preprocessing transform
47
- self.transform = transforms.Compose([
48
- transforms.ToTensor(),
49
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
- ])
51
 
52
  def predict_gaze(self, frame):
 
 
 
53
  preprocessed = preprocess_frame(frame)
54
  preprocessed = preprocessed[0]
55
  preprocessed = self.transform(preprocessed).float().unsqueeze(0)
 
17
 
18
  class GazePredictor:
19
  def __init__(self, model_path):
20
+ # Initialize without moving to device - we'll do that during prediction
21
+ self.model_path = model_path
22
+ self.model = None
23
+ self.device = None
24
+
25
+ # Define transform outside of initialization
26
+ self.transform = transforms.Compose([
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
29
+ ])
30
+
31
+ def _initialize_model(self):
32
+ # Only initialize model when needed (inside ZeroGPU function)
33
+ if self.model is not None:
34
+ return
35
+
36
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
 
38
  # Initialize the custom model
39
  self.model = GazeEstimationModel()
40
 
41
  # Load the state dictionary
42
+ state_dict = torch.load(self.model_path, map_location=self.device)
43
 
44
  # Check if state_dict has 'backbone.' prefix and strip it if necessary
45
  new_state_dict = {}
 
58
  # Move to device and set to evaluation mode
59
  self.model.to(self.device)
60
  self.model.eval()
 
 
 
 
 
 
61
 
62
  def predict_gaze(self, frame):
63
+ # Initialize model if not already done
64
+ self._initialize_model()
65
+
66
  preprocessed = preprocess_frame(frame)
67
  preprocessed = preprocessed[0]
68
  preprocessed = self.transform(preprocessed).float().unsqueeze(0)